# HG changeset patch # User Tuomo Valkonen # Date 1772123923 18000 # Node ID 4f468d35fa29d3152b6f932ea82afc208bf94b92 # Parent 9738b51d90d7481bc3291fa0535a799eb7cd0951 General forward operators, separation of measures into own crate, and other architecture improvements to support the pointsource_pde crate. diff -r 9738b51d90d7 -r 4f468d35fa29 .cargo/config.toml --- a/.cargo/config.toml Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,2 +0,0 @@ -[target.'cfg(all(target_os = "macos"))'] -rustflags = ["-L", "/opt/homebrew/include"] diff -r 9738b51d90d7 -r 4f468d35fa29 .gitignore --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/.gitignore Thu Feb 26 11:38:43 2026 -0500 @@ -0,0 +1,1 @@ +.hgignore \ No newline at end of file diff -r 9738b51d90d7 -r 4f468d35fa29 .hgignore --- a/.hgignore Sun Apr 27 15:03:51 2025 -0500 +++ b/.hgignore Thu Feb 26 11:38:43 2026 -0500 @@ -1,6 +1,9 @@ -^target/ -^debug_out/ -^pointsource.._.*\.txt +syntax:glob +out/ +test/ +target/ +debug_out/ +**/pointsource??_*.txt flamegraph.svg DEADJOE -.*\.orig +**/*.orig diff -r 9738b51d90d7 -r 4f468d35fa29 Cargo.lock --- a/Cargo.lock Sun Apr 27 15:03:51 2025 -0500 +++ b/Cargo.lock Thu Feb 26 11:38:43 2026 -0500 @@ -45,6 +45,7 @@ "num-traits", "numeric_literals", "rayon", + "rustc_version", "serde", "serde_json", "simba", @@ -379,9 +380,7 @@ [[package]] name = "float_extras" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b22b70f8649ea2315955f1a36d964b0e4da482dfaa5f0d04df0d1fb7c338ab7a" +version = "0.1.7" dependencies = [ "libc", ] @@ -394,16 +393,113 @@ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", - "wasi", + "r-efi", + "wasip2", ] 
[[package]] +name = "glam" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "333928d5eb103c5d4050533cec0384302db6be8ef7d3cebd30ec6a35350353da" + +[[package]] +name = "glam" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3abb554f8ee44336b72d522e0a7fe86a29e09f839a36022fa869a7dfe941a54b" + +[[package]] +name = "glam" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4126c0479ccf7e8664c36a2d719f5f2c140fbb4f9090008098d2c291fa5b3f16" + +[[package]] +name = "glam" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01732b97afd8508eee3333a541b9f7610f454bb818669e66e90f5f57c93a776" + +[[package]] +name = "glam" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525a3e490ba77b8e326fb67d4b44b4bd2f920f44d4cc73ccec50adc68e3bee34" + +[[package]] +name = "glam" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b8509e6791516e81c1a630d0bd7fbac36d2fa8712a9da8662e716b52d5051ca" + +[[package]] +name = "glam" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43e957e744be03f5801a55472f593d43fabdebf25a4585db250f04d86b1675f" + +[[package]] +name = "glam" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518faa5064866338b013ff9b2350dc318e14cc4fcd6cb8206d7e7c9886c98815" + +[[package]] +name = "glam" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774" + +[[package]] +name = "glam" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e4afd9ad95555081e109fe1d21f2a30c691b5f0919c67dfa690a2e1eb6bd51c" + +[[package]] +name = "glam" +version = "0.24.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5418c17512bdf42730f9032c74e1ae39afc408745ebb2acf72fbc4691c17945" + +[[package]] +name = "glam" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151665d9be52f9bb40fc7966565d39666f2d1e69233571b71b87791c7e0528b3" + +[[package]] +name = "glam" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e05e7e6723e3455f4818c7b26e855439f7546cf617ef669d1adedb8669e5cb9" + +[[package]] +name = "glam" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "779ae4bf7e8421cf91c0b3b64e7e8b40b862fba4d393f59150042de7c4965a94" + +[[package]] +name = "glam" +version = "0.29.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8babf46d4c1c9d92deac9f7be466f76dfc4482b6452fc5024b5e8daf6ffeb3ee" + +[[package]] +name = "glam" +version = "0.30.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd47b05dddf0005d850e5644cae7f2b14ac3df487979dbfff3b56f20b1a6ae46" + +[[package]] name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -517,9 +613,9 @@ [[package]] name = "libc" -version = "0.2.149" +version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" [[package]] name = "libm" @@ -550,6 +646,17 @@ ] [[package]] +name = "measures" +version = "0.1.0" +dependencies = [ + "alg_tools", + "nalgebra", + "numeric_literals", + "regex", + "serde", +] + +[[package]] name = "memchr" version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -557,11 +664,27 @@ [[package]] name = "nalgebra" -version = "0.33.2" +version = "0.34.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" +checksum = "c4d5b3eff5cd580f93da45e64715e8c20a3996342f1e466599cf7a267a0c2f5f" dependencies = [ "approx", + "glam 0.14.0", + "glam 0.15.2", + "glam 0.16.0", + "glam 0.17.3", + "glam 0.18.0", + "glam 0.19.0", + "glam 0.20.5", + "glam 0.21.3", + "glam 0.22.0", + "glam 0.23.0", + "glam 0.24.2", + "glam 0.25.0", + "glam 0.27.0", + "glam 0.28.0", + "glam 0.29.3", + "glam 0.30.9", "matrixmultiply", "nalgebra-macros", "num-complex", @@ -574,9 +697,9 @@ [[package]] name = "nalgebra-macros" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +checksum = "973e7178a678cfd059ccec50887658d482ce16b0aa9da3888ddeab5cd5eb4889" dependencies = [ "proc-macro2", "quote", @@ -676,9 +799,9 @@ [[package]] name = "once_cell" -version = "1.20.2" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "paste" @@ -688,9 +811,9 @@ [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "pointsource_algs" @@ -705,6 +828,7 @@ "cpu-time", "float_extras", "itertools", + "measures", "nalgebra", "num-traits", "numeric_literals", @@ -714,6 +838,7 @@ "serde", "serde_json", "serde_with", + "thiserror", ] [[package]] @@ -747,21 +872,26 @@ ] [[package]] -name = "rand" -version = "0.8.5" +name = "r-efi" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "libc", "rand_chacha", "rand_core", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -769,18 +899,18 @@ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ "getrandom", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", "rand", @@ -842,6 +972,15 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] name = "rustix" version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -870,6 +1009,12 @@ ] [[package]] +name = "semver" +version = "1.0.26" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + +[[package]] name = "serde" version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -988,6 +1133,26 @@ ] [[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.93", +] + +[[package]] name = "time" version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1058,10 +1223,13 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] [[package]] name = "wasm-bindgen" @@ -1296,3 +1464,9 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" diff -r 9738b51d90d7 -r 4f468d35fa29 Cargo.toml --- a/Cargo.toml Sun Apr 27 15:03:51 2025 -0500 +++ b/Cargo.toml Thu Feb 26 11:38:43 2026 -0500 @@ -17,36 +17,52 @@ "pdps", "fista", "frank-wolfe", - 
"conditional gradient" + "conditional gradient", ] categories = ["mathematics", "science", "computer-vision"] [dependencies.alg_tools] version = "~0.4.0-dev" path = "../alg_tools" -default-features = false +default-features = false features = ["nightly"] +[dependencies.measures] +version = "~0.1.0" +path = "../measures" + [dependencies] serde = { version = "1.0", features = ["derive"] } num-traits = { version = "~0.2.14", features = ["std"] } -rand = "~0.8.5" +rand = "~0.9.2" colored = "~2.1.0" -rand_distr = "~0.4.3" -nalgebra = { version = "~0.33.0", features = ["rand-no-std"] } +rand_distr = "~0.5.1" +nalgebra = { version = "~0.34.0", features = ["rand-no-std"] } itertools = "~0.13.0" numeric_literals = "~0.2.0" GSL = "~7.0.0" -float_extras = "~0.1.6" +float_extras = { path = "../float_extras"} clap = { version = "~4.5.0", features = ["derive", "unicode", "wrap_help"] } -cpu-time = "~1.0.0" +cpu-time = "1.0.0" serde_json = "~1.0.85" chrono = { version = "~0.4.23", features = ["alloc", "std", "serde"] } anyhow = "1.0.95" serde_with = { version = "3.11.0", features = ["macros"] } +thiserror = "2.0.12" + +[features] +default = [] [build-dependencies] regex = "~1.11.0" [profile.release] debug = true + +[lib] +name = "pointsource_algs" +path = "src/lib.rs" + +[[bin]] +name = "pointsource_experiments" +path = "src/main.rs" diff -r 9738b51d90d7 -r 4f468d35fa29 README.md --- a/README.md Sun Apr 27 15:03:51 2025 -0500 +++ b/README.md Thu Feb 26 11:38:43 2026 -0500 @@ -36,8 +36,9 @@ brew install gsl ``` For other operating systems, suggestions are available in the [rust-GSL] - crate documentation. You may need to pass extra `RUSTFLAGS` options to - Cargo in the following steps to locate the library. + crate documentation. If not correctly installed, you may need to pass + extra `RUSTFLAGS` options to Cargo in the following steps to locate the + library. 4. Download [alg_tools] and unpack it under the same directory as this package. 
@@ -61,7 +62,7 @@ When doing this for the first time, several dependencies will be downloaded. Now you can run the default set of experiments with ``` -pointsource_algs -o results +pointsource_experiments -o results ``` The `-o results` option tells `pointsource_algs` to write results in the `results` directory. The option is required. @@ -71,7 +72,7 @@ cargo run --release -- -o results ``` The double-dash separates the options for the Cargo build system -and `pointsource_algs`. +and `pointsource_experiments`. ### Documentation diff -r 9738b51d90d7 -r 4f468d35fa29 build.rs --- a/build.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/build.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,20 +1,39 @@ +use regex::{Captures, Regex}; use std::env; -use regex::{Regex, Captures}; -fn proc>(re : &str, str : A) -> String { - let need_to_escape = Regex::new(r"([_*\\])").unwrap(); - Regex::new(re).unwrap().replace_all(str.as_ref(), |caps : &Captures| { - format!("{}{}{}", - caps.get(1).unwrap().as_str(), - need_to_escape.replace_all(caps.get(2).unwrap().as_str(), "\\$1"), - caps.get(3).unwrap().as_str() - ) - }).to_string() +fn main() { + process_readme(); + // Does not seem to be needed now. + //discover_gsl(); } -fn main() { +/* +/// Discover how to link to gsl, as the gsl crate does not provide this information +fn discover_gsl() { + pkg_config::Config::new().probe("gsl").unwrap(); +} +*/ + +/// `\`-escape `_`, `*`, and ´\\` in matches of `re` within `str`. 
+fn proc>(re: &str, str: A) -> String { + let need_to_escape = Regex::new(r"([_*\\])").unwrap(); + Regex::new(re) + .unwrap() + .replace_all(str.as_ref(), |caps: &Captures| { + format!( + "{}{}{}", + caps.get(1).unwrap().as_str(), + need_to_escape.replace_all(caps.get(2).unwrap().as_str(), "\\$1"), + caps.get(3).unwrap().as_str() + ) + }) + .to_string() +} + +/// Process the README for inclusion in documentation +fn process_readme() { let out_dir = env::var("OUT_DIR").unwrap(); - + // Since rust is stuck in 80's 7-bit gringo ASCII world, so that rustdoc does not support // markdown KaTeX mathematics, we have to process the README to include horrible horrible // horrible escapes for the math, and then use an vomit-inducingly ugly javasccript @@ -22,12 +41,13 @@ println!("cargo:rerun-if-changed=README.md"); - let readme = std::fs::read_to_string("README.md") - .expect("Error reading README"); + let readme = std::fs::read_to_string("README.md").expect("Error reading README"); // Escape _, *, and \ in equations. 
- let readme_uglified = proc(r"(?m)([^$]\$)([^$]+)(\$[^$])", - proc(r"([^$]\$\$)([^$]+)(\$\$[^$])", readme)); + let readme_uglified = proc( + r"(?m)([^$]\$)([^$]+)(\$[^$])", + proc(r"([^$]\$\$)([^$]+)(\$\$[^$])", readme), + ); // Remove the instructions for building the documentation let readme_cut = Regex::new("## Internals(.*|\n)*") .unwrap() diff -r 9738b51d90d7 -r 4f468d35fa29 rustfmt.toml --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/rustfmt.toml Thu Feb 26 11:38:43 2026 -0500 @@ -0,0 +1,3 @@ +overflow_delimited_expr = true +struct_lit_width = 80 + diff -r 9738b51d90d7 -r 4f468d35fa29 src/dataterm.rs --- a/src/dataterm.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/dataterm.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,30 +2,30 @@ Basid definitions for data terms */ -use numeric_literals::replace_float_literals; +//use numeric_literals::replace_float_literals; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::GEMV; -pub use alg_tools::norms::L1; -use alg_tools::norms::Norm; -use alg_tools::instance::{Instance, Space}; +use alg_tools::convex::Norm222; +//use alg_tools::euclidean::Euclidean; +//use alg_tools::instance::{Instance, Space}; +//use alg_tools::linops::GEMV; +use alg_tools::mapping::DataTerm; +use alg_tools::norms::{NormMapping, L1}; -use crate::types::*; -pub use crate::types::L2Squared; -use crate::measures::RNDM; +//use crate::types::*; +/* /// Calculates the residual $Aμ-b$. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn calculate_residual< - X : Space, - I : Instance, - F : Float, - V : Euclidean + Clone, - A : GEMV, + X: Space, + I: Instance, + F: Float, + V: Euclidean + Clone, + A: GEMV, >( - μ : I, - opA : &A, - b : &V + μ: I, + opA: &A, + b: &V, ) -> V { let mut r = b.clone(); opA.gemv(&mut r, 1.0, μ, -1.0); @@ -35,60 +35,24 @@ /// Calculates the residual $A(μ+μ_delta)-b$. 
#[replace_float_literals(F::cast_from(literal))] pub(crate) fn calculate_residual2< - F : Float, - X : Space, - I : Instance, - J : Instance, - V : Euclidean + Clone, - A : GEMV, + F: Float, + X: Space, + I: Instance, + J: Instance, + V: Euclidean + Clone, + A: GEMV, >( - μ : I, - μ_delta : J, - opA : &A, - b : &V + μ: I, + μ_delta: J, + opA: &A, + b: &V, ) -> V { let mut r = b.clone(); opA.gemv(&mut r, 1.0, μ, -1.0); opA.gemv(&mut r, 1.0, μ_delta, 1.0); r } - - -/// Trait for data terms -#[replace_float_literals(F::cast_from(literal))] -pub trait DataTerm { - /// Calculates $F(y)$, where $F$ is the data fidelity. - fn calculate_fit(&self, _residual : &V) -> F; +*/ - /// Calculates $F(Aμ-b)$, where $F$ is the data fidelity. - fn calculate_fit_op, Codomain = V>>( - &self, - μ : I, - opA : &A, - b : &V - ) -> F - where - V : Euclidean + Clone, - I : Instance>, - { - let r = calculate_residual(μ, opA, b); - self.calculate_fit(&r) - } -} - -impl, const N : usize> -DataTerm -for L2Squared { - fn calculate_fit(&self, residual : &V) -> F { - residual.norm2_squared_div2() - } -} - - -impl + Norm, const N : usize> -DataTerm -for L1 { - fn calculate_fit(&self, residual : &V) -> F { - residual.norm(L1) - } -} +pub type L1DataTerm = DataTerm>; +pub type QuadraticDataTerm = DataTerm>; diff -r 9738b51d90d7 -r 4f468d35fa29 src/experiments.rs --- a/src/experiments.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/experiments.rs Thu Feb 26 11:38:43 2026 -0500 @@ -3,37 +3,30 @@ */ //use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use clap::ValueEnum; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::collections::hash_map::DefaultHasher; - +use crate::kernels::SupportProductFirst as Prod; +use crate::kernels::*; +use crate::run::{DefaultAlgorithm, ExperimentBiased, ExperimentV2, Named, RunnableExperiment}; +use crate::types::*; +use crate::{AlgorithmOverrides, ExperimentSetup}; use alg_tools::bisection_tree::*; use 
alg_tools::error::DynResult; use alg_tools::norms::Linfinity; - -use crate::{ExperimentOverrides, AlgorithmOverrides}; -use crate::kernels::*; -use crate::kernels::SupportProductFirst as Prod; -use crate::types::*; -use crate::run::{ - RunnableExperiment, - ExperimentV2, - ExperimentBiased, - Named, - DefaultAlgorithm, -}; -//use crate::fb::FBGenericConfig; -use crate::rand_distr::{SerializableNormal, SaltAndPepper}; +use clap::{Parser, ValueEnum}; +use serde::{Deserialize, Serialize}; +use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +//use crate::fb::InsertionConfig; +use crate::rand_distr::{SaltAndPepper, SerializableNormal}; use crate::regularisation::Regularisation; use alg_tools::euclidean::Euclidean; use alg_tools::instance::Instance; use alg_tools::mapping::Mapping; use alg_tools::operator_arithmetic::{MappingSum, Weighted}; +use itertools::Itertools; +use serde_with::skip_serializing_none; /// Experiments shorthands, to be used with the command line parser - #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[allow(non_camel_case_types)] pub enum DefaultExperiment { @@ -69,49 +62,77 @@ Experiment2D_TV_Fast, } +/// Command line experiment setup overrides +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone, Hash)] +pub struct DefaultExperimentSetup { + /// List of experiments to perform + #[arg(value_name = "EXPERIMENT")] + experiments: Vec, + + #[arg(long)] + /// Regularisation parameter override. + /// + /// Only use if running just a single experiment, as different experiments have different + /// regularisation parameters. + alpha: Option, + + #[arg(long)] + /// Gaussian noise variance override + variance: Option, + + #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] + /// Salt and pepper noise override. + salt_and_pepper: Option>, + + #[arg(long)] + /// Noise seed + noise_seed: Option, +} + macro_rules! 
make_float_constant { ($name:ident = $value:expr) => { #[derive(Debug, Copy, Eq, PartialEq, Clone, Serialize, Deserialize)] #[serde(into = "float")] struct $name; - impl Into for $name { + impl Into for $name { #[inline] - fn into(self) -> float { $value } + fn into(self) -> float { + $value + } } impl Constant for $name { type Type = float; - fn value(&self) -> float { $value } + fn value(&self) -> float { + $value + } } - } + }; } /// Ground-truth measure spike locations and magnitudes for 1D experiments -static MU_TRUE_1D_BASIC : [(float, float); 4] = [ - (0.10, 10.0), - (0.30, 2.0), - (0.70, 3.0), - (0.80, 5.0) -]; +static MU_TRUE_1D_BASIC: [(float, float); 4] = + [(0.10, 10.0), (0.30, 2.0), (0.70, 3.0), (0.80, 5.0)]; /// Ground-truth measure spike locations and magnitudes for 2D experiments -static MU_TRUE_2D_BASIC : [([float; 2], float); 4] = [ +static MU_TRUE_2D_BASIC: [([float; 2], float); 4] = [ ([0.15, 0.15], 10.0), ([0.75, 0.45], 2.0), ([0.80, 0.50], 4.0), - ([0.30, 0.70], 5.0) + ([0.30, 0.70], 5.0), ]; /// The $\{0,1\}$-valued characteristic function of a ball as a [`Mapping`]. 
-#[derive(Debug,Copy,Clone,Serialize,PartialEq)] -struct BallCharacteristic { - pub center : Loc, - pub radius : F, +#[derive(Debug, Copy, Clone, Serialize, PartialEq)] +struct BallCharacteristic { + pub center: Loc, + pub radius: F, } -impl Mapping> for BallCharacteristic { - type Codomain =F; +impl Mapping> for BallCharacteristic { + type Codomain = F; - fn apply>>(&self, i : I) -> F { + fn apply>>(&self, i: I) -> F { if self.center.dist2(i) <= self.radius { F::ONE } else { @@ -120,25 +141,52 @@ } } +/// Trait for customising the experiments available from the command line +impl ExperimentSetup for DefaultExperimentSetup { + type FloatType = f64; + + fn runnables(&self) -> DynResult>>> { + self.experiments + .iter() + .unique() + .map(|e| e.get_experiment(self)) + .try_collect() + } +} + //#[replace_float_literals(F::cast_from(literal))] impl DefaultExperiment { + // fn default_list() -> Vec { + // use DefaultExperiment::*; + // [ + // Experiment1D, + // Experiment1DFast, + // Experiment2D, + // Experiment2DFast, + // Experiment1D_L1, + // ] + // .into() + // } + /// Convert the experiment shorthand into a runnable experiment configuration. 
- pub fn get_experiment(&self, cli : &ExperimentOverrides) -> DynResult>> { - let name = "pointsource".to_string() - + self.to_possible_value().unwrap().get_name(); + fn get_experiment( + &self, + cli: &DefaultExperimentSetup, + ) -> DynResult>> { + let name = "pointsource".to_string() + self.to_possible_value().unwrap().get_name(); let kernel_plot_width = 0.2; - const BASE_SEED : u64 = 915373234; + const BASE_SEED: u64 = 915373234; - const N_SENSORS_1D : usize = 100; - make_float_constant!(SensorWidth1D = 0.4/(N_SENSORS_1D as float)); + const N_SENSORS_1D: usize = 100; + make_float_constant!(SensorWidth1D = 0.4 / (N_SENSORS_1D as float)); - const N_SENSORS_2D : usize = 16; - make_float_constant!(SensorWidth2D = 0.4/(N_SENSORS_2D as float)); + const N_SENSORS_2D: usize = 16; + make_float_constant!(SensorWidth2D = 0.4 / (N_SENSORS_2D as float)); - const N_SENSORS_2D_MORE : usize = 32; - make_float_constant!(SensorWidth2DMore = 0.4/(N_SENSORS_2D_MORE as float)); + //const N_SENSORS_2D_MORE: usize = 32; + //make_float_constant!(SensorWidth2DMore = 0.4 / (N_SENSORS_2D_MORE as float)); make_float_constant!(Variance1 = 0.05.powi(2)); make_float_constant!(CutOff1 = 0.15); @@ -160,45 +208,43 @@ // .. Default::default() // } // ); - let sliding_fb_cut_gaussian = (DefaultAlgorithm::SlidingFB, - AlgorithmOverrides { - theta0 : Some(0.3), - .. Default::default() - } - ); + let sliding_fb_cut_gaussian = (DefaultAlgorithm::SlidingFB, AlgorithmOverrides { + theta0: Some(0.3), + ..Default::default() + }); // let higher_cpos = |alg| (alg, // AlgorithmOverrides { // transport_tolerance_pos : Some(1000.0), // .. Default::default() // } // ); - let higher_cpos_merging = |alg| (alg, - AlgorithmOverrides { - transport_tolerance_pos : Some(1000.0), - merge : Some(true), - fitness_merging : Some(true), - .. 
Default::default() - } - ); - let higher_cpos_merging_steptune = |alg| (alg, - AlgorithmOverrides { - transport_tolerance_pos : Some(1000.0), - theta0 : Some(0.3), - merge : Some(true), - fitness_merging : Some(true), - .. Default::default() - } - ); - let much_higher_cpos_merging_steptune = |alg| (alg, - AlgorithmOverrides { - transport_tolerance_pos : Some(10000.0), - sigma0 : Some(0.15), - theta0 : Some(0.3), - merge : Some(true), - fitness_merging : Some(true), - .. Default::default() - } - ); + let higher_cpos_merging = |alg| { + (alg, AlgorithmOverrides { + transport_tolerance_pos: Some(1000.0), + merge: Some(true), + fitness_merging: Some(true), + ..Default::default() + }) + }; + let higher_cpos_merging_steptune = |alg| { + (alg, AlgorithmOverrides { + transport_tolerance_pos: Some(1000.0), + theta0: Some(0.3), + merge: Some(true), + fitness_merging: Some(true), + ..Default::default() + }) + }; + let much_higher_cpos_merging_steptune = |alg| { + (alg, AlgorithmOverrides { + transport_tolerance_pos: Some(10000.0), + sigma0: Some(0.15), + theta0: Some(0.3), + merge: Some(true), + fitness_merging: Some(true), + ..Default::default() + }) + }; // We add a hash of the experiment name to the configured // noise seed to not use the same noise for different experiments. 
let mut h = DefaultHasher::new(); @@ -210,238 +256,284 @@ use DefaultExperiment::*; Ok(match self { Experiment1D => { - let base_spread = Gaussian { variance : Variance1 }; - let spread_cutoff = BallIndicator { r : CutOff1, exponent : Linfinity }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]].into(), - sensor_count : [N_SENSORS_1D], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.08)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.2))?, - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_1D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth1D, exponent : Linfinity }, - spread : Prod(spread_cutoff, base_spread), - kernel : Prod(AutoConvolution(spread_cutoff), base_spread), - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - sliding_fb_cut_gaussian, - higher_cpos_merging(DefaultAlgorithm::RadonFB), - higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), - ]), - }}) - }, + let base_spread = Gaussian { variance: Variance1 }; + let spread_cutoff = BallIndicator { r: CutOff1, exponent: Linfinity }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]].into(), + sensor_count: [N_SENSORS_1D], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.08)), + noise_distr: SerializableNormal::new(0.0, cli.variance.unwrap_or(0.2))?, + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_1D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth1D, exponent: Linfinity }, + spread: Prod(spread_cutoff, base_spread), + kernel: Prod(AutoConvolution(spread_cutoff), base_spread), + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([ + sliding_fb_cut_gaussian, + higher_cpos_merging(DefaultAlgorithm::RadonFB), + higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), + ]), + }, + }) + } Experiment1DFast => { - let base_spread = HatConv { radius : Hat1 }; - Box::new(Named { name, data : ExperimentV2 
{ - domain : [[0.0, 1.0]].into(), - sensor_count : [N_SENSORS_1D], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.06)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.2))?, - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_1D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth1D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - higher_cpos_merging(DefaultAlgorithm::RadonFB), - higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), - ]), - }}) - }, + let base_spread = HatConv { radius: Hat1 }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]].into(), + sensor_count: [N_SENSORS_1D], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.06)), + noise_distr: SerializableNormal::new(0.0, cli.variance.unwrap_or(0.2))?, + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_1D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth1D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([ + higher_cpos_merging(DefaultAlgorithm::RadonFB), + higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), + ]), + }, + }) + } Experiment2D => { - let base_spread = Gaussian { variance : Variance1 }; - let spread_cutoff = BallIndicator { r : CutOff1, exponent : Linfinity }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]; 2].into(), - sensor_count : [N_SENSORS_2D; 2], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.19)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.25))?, - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_2D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth2D, exponent : Linfinity }, - spread : Prod(spread_cutoff, base_spread), - kernel : Prod(AutoConvolution(spread_cutoff), 
base_spread), - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - sliding_fb_cut_gaussian, - higher_cpos_merging(DefaultAlgorithm::RadonFB), - higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), - ]), - }}) - }, + let base_spread = Gaussian { variance: Variance1 }; + let spread_cutoff = BallIndicator { r: CutOff1, exponent: Linfinity }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]; 2].into(), + sensor_count: [N_SENSORS_2D; 2], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.19)), + noise_distr: SerializableNormal::new(0.0, cli.variance.unwrap_or(0.25))?, + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_2D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth2D, exponent: Linfinity }, + spread: Prod(spread_cutoff, base_spread), + kernel: Prod(AutoConvolution(spread_cutoff), base_spread), + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([ + sliding_fb_cut_gaussian, + higher_cpos_merging(DefaultAlgorithm::RadonFB), + higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), + ]), + }, + }) + } Experiment2DFast => { - let base_spread = HatConv { radius : Hat1 }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]; 2].into(), - sensor_count : [N_SENSORS_2D; 2], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.12)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.15))?, //0.25 - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_2D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth2D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - higher_cpos_merging(DefaultAlgorithm::RadonFB), - higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), - ]), - }}) - }, - Experiment1D_L1 => { - let base_spread = Gaussian { variance : Variance1 }; - let 
spread_cutoff = BallIndicator { r : CutOff1, exponent : Linfinity }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]].into(), - sensor_count : [N_SENSORS_1D], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.1)), - noise_distr : SaltAndPepper::new( - cli.salt_and_pepper.as_ref().map_or(0.6, |v| v[0]), - cli.salt_and_pepper.as_ref().map_or(0.4, |v| v[1]) - )?, - dataterm : DataTerm::L1, - μ_hat : MU_TRUE_1D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth1D, exponent : Linfinity }, - spread : Prod(spread_cutoff, base_spread), - kernel : Prod(AutoConvolution(spread_cutoff), base_spread), - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::new(), - }}) - }, - Experiment1D_L1_Fast => { - let base_spread = HatConv { radius : Hat1 }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]].into(), - sensor_count : [N_SENSORS_1D], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.12)), - noise_distr : SaltAndPepper::new( - cli.salt_and_pepper.as_ref().map_or(0.6, |v| v[0]), - cli.salt_and_pepper.as_ref().map_or(0.4, |v| v[1]) - )?, - dataterm : DataTerm::L1, - μ_hat : MU_TRUE_1D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth1D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::new(), - }}) - }, - Experiment2D_L1 => { - let base_spread = Gaussian { variance : Variance1 }; - let spread_cutoff = BallIndicator { r : CutOff1, exponent : Linfinity }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]; 2].into(), - sensor_count : [N_SENSORS_2D; 2], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.35)), - noise_distr : SaltAndPepper::new( - cli.salt_and_pepper.as_ref().map_or(0.8, |v| v[0]), - cli.salt_and_pepper.as_ref().map_or(0.2, |v| v[1]) - )?, - dataterm : DataTerm::L1, - μ_hat : 
MU_TRUE_2D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth2D, exponent : Linfinity }, - spread : Prod(spread_cutoff, base_spread), - kernel : Prod(AutoConvolution(spread_cutoff), base_spread), - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - ]), - }}) - }, - Experiment2D_L1_Fast => { - let base_spread = HatConv { radius : Hat1 }; - Box::new(Named { name, data : ExperimentV2 { - domain : [[0.0, 1.0]; 2].into(), - sensor_count : [N_SENSORS_2D; 2], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.40)), - noise_distr : SaltAndPepper::new( - cli.salt_and_pepper.as_ref().map_or(0.8, |v| v[0]), - cli.salt_and_pepper.as_ref().map_or(0.2, |v| v[1]) - )?, - dataterm : DataTerm::L1, - μ_hat : MU_TRUE_2D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth2D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, - kernel_plot_width, - noise_seed, - default_merge_radius, - algorithm_overrides: HashMap::from([ - ]), - }}) - }, - Experiment1D_TV_Fast => { - let base_spread = HatConv { radius : HatBias }; - Box::new(Named { name, data : ExperimentBiased { - λ : 0.02, - bias : MappingSum::new([ - Weighted::new(1.0, BallCharacteristic{ center : 0.3.into(), radius : 0.2 }), - Weighted::new(0.5, BallCharacteristic{ center : 0.6.into(), radius : 0.3 }), - ]), - base : ExperimentV2 { - domain : [[0.0, 1.0]].into(), - sensor_count : [N_SENSORS_1D], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.2)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.1))?, - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_1D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth1D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, + let base_spread = HatConv { radius: Hat1 }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]; 2].into(), + sensor_count: [N_SENSORS_2D; 2], + regularisation: 
Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.12)), + noise_distr: SerializableNormal::new(0.0, cli.variance.unwrap_or(0.15))?, //0.25 + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_2D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth2D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, kernel_plot_width, noise_seed, default_merge_radius, algorithm_overrides: HashMap::from([ - higher_cpos_merging_steptune(DefaultAlgorithm::RadonForwardPDPS), - higher_cpos_merging_steptune(DefaultAlgorithm::RadonSlidingPDPS), + higher_cpos_merging(DefaultAlgorithm::RadonFB), + higher_cpos_merging(DefaultAlgorithm::RadonSlidingFB), ]), }, - }}) - }, - Experiment2D_TV_Fast => { - let base_spread = HatConv { radius : Hat1 }; - Box::new(Named { name, data : ExperimentBiased { - λ : 0.005, - bias : MappingSum::new([ - Weighted::new(1.0, BallCharacteristic{ center : [0.3, 0.3].into(), radius : 0.2 }), - Weighted::new(0.5, BallCharacteristic{ center : [0.6, 0.6].into(), radius : 0.3 }), - ]), - base : ExperimentV2 { - domain : [[0.0, 1.0]; 2].into(), - sensor_count : [N_SENSORS_2D; 2], - regularisation : Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.06)), - noise_distr : SerializableNormal::new(0.0, cli.variance.unwrap_or(0.15))?, //0.25 - dataterm : DataTerm::L2Squared, - μ_hat : MU_TRUE_2D_BASIC.into(), - sensor : BallIndicator { r : SensorWidth2D, exponent : Linfinity }, - spread : base_spread, - kernel : base_spread, + }) + } + Experiment1D_L1 => { + let base_spread = Gaussian { variance: Variance1 }; + let spread_cutoff = BallIndicator { r: CutOff1, exponent: Linfinity }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]].into(), + sensor_count: [N_SENSORS_1D], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.1)), + noise_distr: SaltAndPepper::new( + cli.salt_and_pepper.as_ref().map_or(0.6, |v| v[0]), + cli.salt_and_pepper.as_ref().map_or(0.4, |v| v[1]), + )?, + dataterm: DataTermType::L1, + μ_hat: 
MU_TRUE_1D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth1D, exponent: Linfinity }, + spread: Prod(spread_cutoff, base_spread), + kernel: Prod(AutoConvolution(spread_cutoff), base_spread), + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::new(), + }, + }) + } + Experiment1D_L1_Fast => { + let base_spread = HatConv { radius: Hat1 }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]].into(), + sensor_count: [N_SENSORS_1D], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.12)), + noise_distr: SaltAndPepper::new( + cli.salt_and_pepper.as_ref().map_or(0.6, |v| v[0]), + cli.salt_and_pepper.as_ref().map_or(0.4, |v| v[1]), + )?, + dataterm: DataTermType::L1, + μ_hat: MU_TRUE_1D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth1D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::new(), + }, + }) + } + Experiment2D_L1 => { + let base_spread = Gaussian { variance: Variance1 }; + let spread_cutoff = BallIndicator { r: CutOff1, exponent: Linfinity }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]; 2].into(), + sensor_count: [N_SENSORS_2D; 2], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.35)), + noise_distr: SaltAndPepper::new( + cli.salt_and_pepper.as_ref().map_or(0.8, |v| v[0]), + cli.salt_and_pepper.as_ref().map_or(0.2, |v| v[1]), + )?, + dataterm: DataTermType::L1, + μ_hat: MU_TRUE_2D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth2D, exponent: Linfinity }, + spread: Prod(spread_cutoff, base_spread), + kernel: Prod(AutoConvolution(spread_cutoff), base_spread), kernel_plot_width, noise_seed, default_merge_radius, - algorithm_overrides: HashMap::from([ - much_higher_cpos_merging_steptune(DefaultAlgorithm::RadonForwardPDPS), - much_higher_cpos_merging_steptune(DefaultAlgorithm::RadonSlidingPDPS), + algorithm_overrides: 
HashMap::from([]), + }, + }) + } + Experiment2D_L1_Fast => { + let base_spread = HatConv { radius: Hat1 }; + Box::new(Named { + name, + data: ExperimentV2 { + domain: [[0.0, 1.0]; 2].into(), + sensor_count: [N_SENSORS_2D; 2], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.40)), + noise_distr: SaltAndPepper::new( + cli.salt_and_pepper.as_ref().map_or(0.8, |v| v[0]), + cli.salt_and_pepper.as_ref().map_or(0.2, |v| v[1]), + )?, + dataterm: DataTermType::L1, + μ_hat: MU_TRUE_2D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth2D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([]), + }, + }) + } + Experiment1D_TV_Fast => { + let base_spread = HatConv { radius: HatBias }; + Box::new(Named { + name, + data: ExperimentBiased { + λ: 0.02, + bias: MappingSum::new([ + Weighted::new(1.0, BallCharacteristic { + center: 0.3.into(), + radius: 0.2, + }), + Weighted::new(0.5, BallCharacteristic { + center: 0.6.into(), + radius: 0.3, + }), ]), + base: ExperimentV2 { + domain: [[0.0, 1.0]].into(), + sensor_count: [N_SENSORS_1D], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.2)), + noise_distr: SerializableNormal::new(0.0, cli.variance.unwrap_or(0.1))?, + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_1D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth1D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([ + higher_cpos_merging_steptune(DefaultAlgorithm::RadonForwardPDPS), + higher_cpos_merging_steptune(DefaultAlgorithm::RadonSlidingPDPS), + ]), + }, }, - }}) - }, + }) + } + Experiment2D_TV_Fast => { + let base_spread = HatConv { radius: Hat1 }; + Box::new(Named { + name, + data: ExperimentBiased { + λ: 0.005, + bias: MappingSum::new([ + Weighted::new(1.0, BallCharacteristic { + center: [0.3, 
0.3].into(), + radius: 0.2, + }), + Weighted::new(0.5, BallCharacteristic { + center: [0.6, 0.6].into(), + radius: 0.3, + }), + ]), + base: ExperimentV2 { + domain: [[0.0, 1.0]; 2].into(), + sensor_count: [N_SENSORS_2D; 2], + regularisation: Regularisation::NonnegRadon(cli.alpha.unwrap_or(0.06)), + noise_distr: SerializableNormal::new( + 0.0, + cli.variance.unwrap_or(0.15), + )?, //0.25 + dataterm: DataTermType::L222, + μ_hat: MU_TRUE_2D_BASIC.into(), + sensor: BallIndicator { r: SensorWidth2D, exponent: Linfinity }, + spread: base_spread, + kernel: base_spread, + kernel_plot_width, + noise_seed, + default_merge_radius, + algorithm_overrides: HashMap::from([ + much_higher_cpos_merging_steptune( + DefaultAlgorithm::RadonForwardPDPS, + ), + much_higher_cpos_merging_steptune( + DefaultAlgorithm::RadonSlidingPDPS, + ), + ]), + }, + }, + }) + } }) } } - diff -r 9738b51d90d7 -r 4f468d35fa29 src/fb.rs --- a/src/fb.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/fb.rs Thu Feb 26 11:38:43 2026 -0500 @@ -74,37 +74,34 @@

We solve this with either SSN or FB as determined by -[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. +[`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`]. */ +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; +use crate::regularisation::RegTerm; +use crate::types::*; +use alg_tools::error::DynResult; +use alg_tools::instance::Instance; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::mapping::DifferentiableMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; use colored::Colorize; use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; -use alg_tools::euclidean::Euclidean; -use alg_tools::instance::Instance; -use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::linops::{Mapping, GEMV}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; - -use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; -use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; -use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, RNDM}; -use crate::plot::{PlotLookup, Plotting, SeqPlotter}; -pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; -use crate::regularisation::RegTerm; -use crate::types::*; - /// Settings for [`pointsource_fb_reg`]. 
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct FBConfig { /// Step length scaling pub τ0: F, + // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] + pub σp0: F, /// Generic parameters - pub generic: FBGenericConfig, + pub insertion: InsertionConfig, } #[replace_float_literals(F::cast_from(literal))] @@ -112,12 +109,13 @@ fn default() -> Self { FBConfig { τ0: 0.99, - generic: Default::default(), + σp0: 0.99, + insertion: Default::default(), } } } -pub(crate) fn prune_with_stats(μ: &mut RNDM) -> usize { +pub(crate) fn prune_with_stats(μ: &mut RNDM) -> usize { let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); @@ -125,30 +123,19 @@ } #[replace_float_literals(F::cast_from(literal))] -pub(crate) fn postprocess< - F: Float, - V: Euclidean + Clone, - A: GEMV, Codomain = V>, - D: DataTerm, - const N: usize, ->( - mut μ: RNDM, - config: &FBGenericConfig, - dataterm: D, - opA: &A, - b: &V, -) -> RNDM +pub(crate) fn postprocess) -> F, const N: usize>( + mut μ: RNDM, + config: &InsertionConfig, + f: Dat, +) -> DynResult> where - RNDM: SpikeMerging, - for<'a> &'a RNDM: Instance>, + RNDM: SpikeMerging, + for<'a> &'a RNDM: Instance>, { - μ.merge_spikes_fitness( - config.final_merging_method(), - |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), - |&v| v, - ); + //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v); + μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v); μ.prune(); - μ + Ok(μ) } /// Iteratively solve the pointsource localisation problem using forward-backward splitting. @@ -161,50 +148,41 @@ /// /// For details on the mathematical formulation, see the [module level](self) documentation. 
/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fb_reg( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_fb_reg( + f: &Dat, + reg: &Reg, prox_penalty: &P, fbconfig: &FBConfig, iterator: I, - mut plotter: SeqPlotter, -) -> RNDM + mut plotter: Plot, + μ0 : Option>, +) -> DynResult> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory>, - for<'b> &'b A::Observable: std::ops::Neg, - A: ForwardModel, F> + AdjointProductBoundedBy, P, FloatType = F>, - A::PreadjointCodomain: RealMapping, - PlotLookup: Plotting, - RNDM: SpikeMerging, - Reg: RegTerm, - P: ProxPenalty, + I: AlgIteratorFactory>, + RNDM: SpikeMerging, + Dat: DifferentiableMapping, Codomain = F>, + Dat::DerivativeDomain: ClosedMul, + Reg: RegTerm, F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Plot: Plotter>, { // Set up parameters - let config = &fbconfig.generic; - let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); + let config = &fbconfig.insertion; + let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. 
let tolerance = config.tolerance * τ * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); // Statistics - let full_stats = |residual: &A::Observable, μ: &RNDM, ε, stats| IterInfo { - value: residual.norm2_squared_div2() + reg.apply(μ), + let full_stats = |μ: &RNDM, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), n_spikes: μ.len(), ε, //postprocessing: config.postprocessing.then(|| μ.clone()), @@ -213,9 +191,10 @@ let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); + // TODO: optimise τ to be applied to residual. + let mut τv = f.differential(&μ) * τ; // Save current base point let μ_base = μ.clone(); @@ -223,7 +202,7 @@ // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, - ); + )?; // Prune and possibly merge spikes if config.merge_now(&state) { @@ -236,34 +215,27 @@ ε, config, ®, - Some(|μ̃: &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), + Some(|μ̃: &RNDM| f.apply(μ̃)), ); } stats.pruned += prune_with_stats(&mut μ); - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats( - &residual, - &μ, - ε, - std::mem::replace(&mut stats, IterInfo::new()), - ) + full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - postprocess(μ, config, L2Squared, opA, b) + //postprocess(μ_prev, 
config, f) + postprocess(μ, config, |μ̃| f.apply(μ̃)) } /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. @@ -276,38 +248,30 @@ /// /// For details on the mathematical formulation, see the [module level](self) documentation. /// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fista_reg( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_fista_reg( + f: &Dat, + reg: &Reg, prox_penalty: &P, fbconfig: &FBConfig, iterator: I, - mut plotter: SeqPlotter, -) -> RNDM + mut plotter: Plot, + μ0: Option> +) -> DynResult> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory>, - for<'b> &'b A::Observable: std::ops::Neg, - A: ForwardModel, F> + AdjointProductBoundedBy, P, FloatType = F>, - A::PreadjointCodomain: RealMapping, - PlotLookup: Plotting, - RNDM: SpikeMerging, - Reg: RegTerm, - P: ProxPenalty, + I: AlgIteratorFactory>, + RNDM: SpikeMerging, + Dat: DifferentiableMapping, Codomain = F>, + Dat::DerivativeDomain: ClosedMul, + Reg: RegTerm, F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Plot: Plotter>, { // Set up parameters - let config = &fbconfig.generic; - let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); + let config = &fbconfig.insertion; + let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; let mut λ = 1.0; // We multiply tolerance by τ for FB 
since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -315,14 +279,13 @@ let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut μ_prev = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut μ_prev = μ.clone(); let mut warned_merging = false; // Statistics - let full_stats = |ν: &RNDM, ε, stats| IterInfo { - value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), + let full_stats = |ν: &RNDM, ε, stats| IterInfo { + value: f.apply(ν) + reg.apply(ν), n_spikes: ν.len(), ε, // postprocessing: config.postprocessing.then(|| ν.clone()), @@ -333,7 +296,7 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); + let mut τv = f.differential(&μ) * τ; // Save current base point let μ_base = μ.clone(); @@ -341,7 +304,7 @@ // Insert new spikes and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, - ); + )?; // (Do not) merge spikes. if config.merge_now(&state) && !warned_merging { @@ -369,9 +332,6 @@ debug_assert!(μ.len() <= n_before_prune); stats.pruned += n_before_prune - μ.len(); - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; @@ -385,5 +345,6 @@ ε = tolerance.update(ε, iter); } - postprocess(μ_prev, config, L2Squared, opA, b) + //postprocess(μ_prev, config, f) + postprocess(μ_prev, config, |μ̃| f.apply(μ̃)) } diff -r 9738b51d90d7 -r 4f468d35fa29 src/forward_model.rs --- a/src/forward_model.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/forward_model.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,14 +2,15 @@ Forward models from discrete measures to observations. 
*/ -use alg_tools::error::DynError; -use alg_tools::euclidean::Euclidean; -use alg_tools::instance::Instance; +use crate::dataterm::QuadraticDataTerm; +use crate::measures::{Radon, RNDM}; +use crate::types::*; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::{ClosedEuclidean, Euclidean}; pub use alg_tools::linops::*; use alg_tools::norms::{Norm, NormExponent, L2}; +use serde::{Deserialize, Serialize}; -use crate::measures::Radon; -use crate::types::*; pub mod bias; pub mod sensor_grid; @@ -21,13 +22,12 @@ + GEMV + Preadjointable where - for<'a> Self::Observable: Instance, - Domain: Norm, + Domain: Norm, { /// The codomain or value space (of “observables”) for this operator. /// It is assumed to be a [`Euclidean`] space, and therefore also (identified with) /// the domain of the preadjoint. - type Observable: Euclidean + AXPY + Space + Clone; + type Observable: ClosedEuclidean + Clone; /// Write an observable into a file. fn write_observable(&self, b: &Self::Observable, prefix: String) -> DynError; @@ -36,48 +36,66 @@ fn zero_observable(&self) -> Self::Observable; } -/// Trait for operators $A$ for which $A_*A$ is bounded by some other operator. -pub trait AdjointProductBoundedBy: Linear { - type FloatType: Float; - /// Return $L$ such that $A_*A ≤ LD$. - fn adjoint_product_bound(&self, other: &D) -> Option; +/// Guess for [`BoundedCurvature`] calculations. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +pub enum BoundedCurvatureGuess { + /// No iterate $μ^k$ is worse than $μ=0$. + BetterThanZero, } -/// Trait for operators $A$ for which $A_*A$ is bounded by a diagonal operator. -pub trait AdjointProductPairBoundedBy: Linear { - type FloatType: Float; - /// Return $(L, L_z)$ such that $A_*A ≤ (L_1 D_1, L_2 D_2)$. - fn adjoint_product_pair_bound( +/// Curvature error control: helper bounds for (4.2d), (5.2a), (5.2b), (5.15a), and (5.16a). 
+/// +/// Based on Lemma 5.11 and Example 5.12, the helper bound for (5.15a) and (5.16a) is (3.8). +/// Thus, subject to `guess` being correct, returns factor $(ℓ_F, Θ²)$ such that +/// $B_{F'(μ)} dγ ≤ ℓ_F c_2$ and $⟨F'(μ+Δ)-F'(μ)|Δ⟩ ≤ Θ²|γ|(c_2)$, where $Δ=(π_♯^1-π_♯^0)γ$. +/// +/// This trait is supposed to be implemented by the data term $F$, in the basic case a +/// [`Mapping`] from [`RNDM`] to a [`Float`] `F`. +/// The generic implementation for operators that satisfy [`BasicCurvatureBoundEstimates`] +/// uses Remark 5.15 and Example 5.16 for (4.2d) and (5.2a), and (5.2b); +/// and Lemma 3.8 for (3.8). +pub trait BoundedCurvature { + /// Returns $(ℓ_F, Θ²)$ or individual errors for each. + fn curvature_bound_components( &self, - other1: &D1, - other_2: &D2, - ) -> Option<(Self::FloatType, Self::FloatType)>; + guess: BoundedCurvatureGuess, + ) -> (DynResult, DynResult); } -/* -/// Trait for [`ForwardModel`]s whose preadjoint has Lipschitz values. -pub trait LipschitzValues { - type FloatType : Float; - /// Return (if one exists) a factor $L$ such that $A_*z$ is $L$-Lipschitz for all - /// $z$ in the unit ball. - fn value_unit_lipschitz_factor(&self) -> Option { - None - } +/// Curvature error control: helper bounds for (4.2d), (5.2a), (5.2b), (5.15a), and (5.16a) +/// for quadratic dataterms $F(μ) = \frac{1}{2}\|Aμ-b\|^2$. +/// +/// This trait is to be implemented by the [`Linear`] operator $A$, in the basic from +/// [`RNDM`] to a an Euclidean space. +/// It is used by implementations of [`BoundedCurvature`] for $F$. +/// +/// Based on Lemma 5.11 and Example 5.12, the helper bound for (5.15a) and (5.16a) is (3.8). +/// This trait provides the factor $θ²$ of (3.8) as determined by Lemma 3.8. +/// To aid in calculating (4.2d), (5.2a), (5.2b), motivated by Example 5.16, it also provides +/// $ℓ_F^0$ such that $∇v^k$ $ℓ_F^0 \|Aμ-b\|$-Lipschitz. Here $v^k := F'(∪^k)$. 
+pub trait BasicCurvatureBoundEstimates { + /// Returns $(ℓ_F^0, Θ²)$ or individual errors for each. + fn basic_curvature_bound_components(&self) -> (DynResult, DynResult); +} - /// Return (if one exists) a factor $L$ such that $∇A_*z$ is $L$-Lipschitz for all - /// $z$ in the unit ball. - fn value_diff_unit_lipschitz_factor(&self) -> Option { - None +impl BoundedCurvature for QuadraticDataTerm, A> +where + F: Float, + Z: Clone + Space + Euclidean, + A: Mapping, Codomain = Z>, + A: BasicCurvatureBoundEstimates, +{ + fn curvature_bound_components( + &self, + guess: BoundedCurvatureGuess, + ) -> (DynResult, DynResult) { + match guess { + BoundedCurvatureGuess::BetterThanZero => { + let opA = self.operator(); + let b = self.data(); + let (ℓ_F0, θ2) = opA.basic_curvature_bound_components(); + (ℓ_F0.map(|l| l * b.norm2()), θ2) + } + } } } -*/ - -/// Trait for [`ForwardModel`]s that satisfy bounds on curvature. -pub trait BoundedCurvature { - type FloatType: Float; - - /// Returns factor $ℓ_F$ and $ℓ_r$ such that - /// $B_{F'(μ)} dγ ≤ ℓ_F c_2$ and $⟨F'(μ)+F'(μ+Δ)|Δ⟩ ≤ ℓ_r|γ|(c_2)$, - /// where $Δ=(π_♯^1-π_♯^0)γ$. - fn curvature_bound_components(&self) -> (Option, Option); -} diff -r 9738b51d90d7 -r 4f468d35fa29 src/forward_model/bias.rs --- a/src/forward_model/bias.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/forward_model/bias.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,12 +2,18 @@ Simple parametric forward model. 
*/ -use super::{AdjointProductBoundedBy, AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; -use crate::measures::RNDM; +use super::{BasicCurvatureBoundEstimates, BoundedCurvature, BoundedCurvatureGuess, ForwardModel}; +use crate::dataterm::QuadraticDataTerm; +use crate::measures::{Radon, RNDM}; +use crate::prox_penalty::{RadonSquared, StepLengthBoundPair}; +use crate::seminorms::DiscreteMeasureOp; use alg_tools::direct_product::Pair; -use alg_tools::error::DynError; -use alg_tools::linops::{IdOp, Linear, RowOp, ZeroOp, AXPY}; -use alg_tools::mapping::Space; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::ClosedEuclidean; +use alg_tools::linops::{BoundedLinear, IdOp, RowOp, AXPY}; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Mapping, Space}; +use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::{Norm, NormExponent, PairNorm, L2}; use alg_tools::types::{ClosedAdd, Float}; use numeric_literals::replace_float_literals; @@ -16,9 +22,9 @@ for RowOp> where E: NormExponent, - Domain: Space + Norm, + Domain: Space + Norm, F: Float, - A::Observable: ClosedAdd + Norm + 'static, + A::Observable: ClosedAdd + Norm + AXPY + 'static, A: ForwardModel + 'static, { type Observable = A::Observable; @@ -34,23 +40,46 @@ } #[replace_float_literals(F::cast_from(literal))] -impl AdjointProductPairBoundedBy, D, IdOp> - for RowOp> +impl<'a, F, A, 𝒟, Z, const N: usize> + StepLengthBoundPair, Z>, RowOp>>> + for Pair<&'a 𝒟, &'a IdOp> where - Domain: Space, - F: Float, - Z: Clone + Space + ClosedAdd, - A: AdjointProductBoundedBy, - A::Codomain: ClosedAdd, + RNDM: Space + for<'b> Norm<&'b 𝒟, F>, + F: Float + ToNalgebraRealField, + 𝒟: DiscreteMeasureOp, F>, + Z: Clone + ClosedEuclidean, + A: for<'b> BoundedLinear, &'b 𝒟, L2, F, Codomain = Z>, + for<'b> &'b 𝒟: NormExponent, { - type FloatType = F; + fn step_length_bound_pair( + &self, + f: &QuadraticDataTerm, Z>, RowOp>>, + ) -> DynResult<(F, F)> { + let l_0 = 
f.operator().0.opnorm_bound(self.0, L2)?.powi(2); + // [A_*; B_*][A, B] = [A_*A, A_* B; B_* A, B_* B] ≤ diag(2A_*A, 2B_*B) + // ≤ diag(2l_A𝒟_A, 2l_B𝒟_B), where now 𝒟_B=Id and l_B=1. + Ok((2.0 * l_0, 2.0)) + } +} - fn adjoint_product_pair_bound(&self, d: &D, _: &IdOp) -> Option<(F, F)> { - self.0.adjoint_product_bound(d).map(|l_0| { - // [A_*; B_*][A, B] = [A_*A, A_* B; B_* A, B_* B] ≤ diag(2A_*A, 2B_*B) - // ≤ diag(2l_A𝒟_A, 2l_B𝒟_B), where now 𝒟_B=Id and l_B=1. - (2.0 * l_0, 2.0) - }) +#[replace_float_literals(F::cast_from(literal))] +impl<'a, F, A, Z, const N: usize> + StepLengthBoundPair, Z>, RowOp>>> + for Pair<&'a RadonSquared, &'a IdOp> +where + RNDM: Space + Norm, + F: Float + ToNalgebraRealField, + Z: Clone + ClosedEuclidean, + A: BoundedLinear, Radon, L2, F, Codomain = Z>, +{ + fn step_length_bound_pair( + &self, + f: &QuadraticDataTerm, Z>, RowOp>>, + ) -> DynResult<(F, F)> { + let l_0 = f.operator().0.opnorm_bound(Radon, L2)?.powi(2); + // [A_*; B_*][A, B] = [A_*A, A_* B; B_* A, B_* B] ≤ diag(2A_*A, 2B_*B) + // ≤ diag(2l_A𝒟_A, 2l_B𝒟_B), where now 𝒟_B=Id and l_B=1. + Ok((2.0 * l_0, 2.0)) } } @@ -79,30 +108,44 @@ } */ -impl BoundedCurvature for RowOp> +use BoundedCurvatureGuess::*; + +/// Curvature error control: helper bounds for (4.2d), (5.2a), (5.2b), (5.15a), and (5.16a). +/// +/// Based on Lemma 5.11 and Example 5.12, the helper bound for (5.15a) and (5.16a) is (3.8). +/// Due to Example 6.1, defining $v^k$ as the projection $F'$ to the predual space of the +/// measures, returns, if possible, and subject to the guess being correct, factors $ℓ_F$ and +/// $Θ²$ such that $B_{P_ℳ^* F'(μ, z)} dγ ≤ ℓ_F c_2$ and +/// $⟨P_ℳ^*[F'(μ+Δ, z)-F'(μ, z)]|Δ⟩ ≤ Θ²|γ|(c_2)‖γ‖$, where $Δ=(π_♯^1-π_♯^0)γ$. +/// For our $F(μ, z)=\frac{1}{2}\|Aμ+z-b\|^2$, we have $F'(μ, z)=A\_*(Aμ+z-b)$, so +/// $F'(μ+Δ, z)-F'(μ, z)=A\_*AΔ$ is independent of $z$, and the bounding can be calculated +/// as in the case without $z$, based on Lemma 3.8. 
+/// +/// We use Remark 5.15 and Example 5.16 for (4.2d) and (5.2a) with the additional effect of $z$. +/// This is based on a Lipschitz estimate for $∇v^k$, where we still, similarly to the Example, +/// have $∇v^k(x)=∇A\_*(x)[Aμ^k+z^k-b]$. We estimate the final term similarly to the example, +/// assuming for the guess [`BetterThanZero`] that every iterate is better than $(μ, z)=0$. +/// This the final estimate is exactly as in the example, without $z$. +/// Thus we can directly use [`BasicCurvatureBoundEstimates`] on the operator $A$. +impl BoundedCurvature + for QuadraticDataTerm, Z>, RowOp>> where F: Float, - Z: Clone + Space + ClosedAdd, - A: BoundedCurvature, + Z: Clone + ClosedEuclidean, + A: Mapping, Codomain = Z>, + A: BasicCurvatureBoundEstimates, { - type FloatType = F; - - fn curvature_bound_components(&self) -> (Option, Option) { - self.0.curvature_bound_components() + fn curvature_bound_components( + &self, + guess: BoundedCurvatureGuess, + ) -> (DynResult, DynResult) { + match guess { + BetterThanZero => { + let opA = &self.operator().0; + let b = self.data(); + let (ℓ_F0, θ2) = opA.basic_curvature_bound_components(); + (ℓ_F0.map(|l| l * b.norm2()), θ2) + } + } } } - -#[replace_float_literals(F::cast_from(literal))] -impl<'a, F, D, XD, Y, const N: usize> AdjointProductBoundedBy, D> - for ZeroOp<'a, RNDM, XD, Y, F> -where - F: Float, - Y: AXPY + Clone, - D: Linear>, -{ - type FloatType = F; - /// Return $L$ such that $A_*A ≤ L𝒟$ is bounded by some `other` operator $𝒟$. 
- fn adjoint_product_bound(&self, _: &D) -> Option { - Some(0.0) - } -} diff -r 9738b51d90d7 -r 4f468d35fa29 src/forward_model/sensor_grid.rs --- a/src/forward_model/sensor_grid.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/forward_model/sensor_grid.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,13 +2,17 @@ Sensor grid forward model */ -use nalgebra::base::{DMatrix, DVector}; -use numeric_literals::replace_float_literals; -use std::iter::Zip; -use std::ops::RangeFrom; - +use super::{BasicCurvatureBoundEstimates, ForwardModel}; +use crate::frank_wolfe::FindimQuadraticModel; +use crate::kernels::{AutoConvolution, BoundedBy, Convolution}; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::preadjoint_helper::PreadjointHelper; +use crate::seminorms::{ConvolutionOp, SimpleConvolutionKernel}; +use crate::types::*; use alg_tools::bisection_tree::*; -use alg_tools::error::DynError; +use alg_tools::bounds::Bounded; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::Euclidean; use alg_tools::instance::Instance; use alg_tools::iter::{MapX, Mappable}; use alg_tools::lingrid::*; @@ -18,79 +22,74 @@ use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::{Linfinity, Norm, L1, L2}; use alg_tools::tabledump::write_csv; +use anyhow::anyhow; +use nalgebra::base::{DMatrix, DVector}; +use numeric_literals::replace_float_literals; +use std::iter::Zip; +use std::ops::RangeFrom; -use super::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel}; -use crate::frank_wolfe::FindimQuadraticModel; -use crate::kernels::{AutoConvolution, BoundedBy, Convolution}; -use crate::measures::{DiscreteMeasure, Radon}; -use crate::preadjoint_helper::PreadjointHelper; -use crate::seminorms::{ConvolutionOp, SimpleConvolutionKernel}; -use crate::types::*; - -type RNDM = DiscreteMeasure, F>; - -pub type ShiftedSensor = Shift, F, N>; +pub type ShiftedSensor = Shift, N, F>; /// Trait for physical convolution models. Has blanket implementation for all cases. 
-pub trait Spread: - 'static + Clone + Support + RealMapping + Bounded +pub trait Spread: + 'static + Clone + Support + RealMapping + Bounded { } -impl Spread for T +impl Spread for T where F: Float, - T: 'static + Clone + Support + Bounded + RealMapping, + T: 'static + Clone + Support + Bounded + RealMapping, { } /// Trait for compactly supported sensors. Has blanket implementation for all cases. -pub trait Sensor: - Spread + Norm + Norm +pub trait Sensor: + Spread + Norm + Norm { } -impl Sensor for T +impl Sensor for T where F: Float, - T: Spread + Norm + Norm, + T: Spread + Norm + Norm, { } pub trait SensorGridBT: - Clone + BTImpl> + Clone + BTImpl> where F: Float, - S: Sensor, - P: Spread, + S: Sensor, + P: Spread, { } impl SensorGridBT for T where - T: Clone + BTImpl>, + T: Clone + BTImpl>, F: Float, - S: Sensor, - P: Spread, + S: Sensor, + P: Spread, { } // We need type alias bounds to access associated types #[allow(type_alias_bounds)] pub type SensorGridBTFN, const N: usize> = - BTFN, BT, N>; + BTFN, BT, N>; /// Sensor grid forward model #[derive(Clone)] pub struct SensorGrid where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, BT: SensorGridBT, { - domain: Cube, + domain: Cube, sensor_count: [usize; N], sensor: S, spread: P, @@ -102,16 +101,16 @@ where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, { /// Create a new sensor grid. /// /// The parameter `depth` indicates the search depth of the created [`BT`]s /// for the adjoint values. pub fn new( - domain: Cube, + domain: Cube, sensor_count: [usize; N], sensor: S, spread: P, @@ -141,12 +140,12 @@ where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { /// Return the grid of sensor locations. 
- pub fn grid(&self) -> LinGrid { + pub fn grid(&self) -> LinGrid { lingrid_centered(&self.domain, &self.sensor_count) } @@ -157,7 +156,7 @@ /// Constructs a sensor shifted by `x`. #[inline] - fn shifted_sensor(&self, x: Loc) -> ShiftedSensor { + fn shifted_sensor(&self, x: Loc) -> ShiftedSensor { self.base_sensor.clone().shift(x) } @@ -180,44 +179,44 @@ } } -impl Mapping> for SensorGrid +impl Mapping> for SensorGrid where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { type Codomain = DVector; #[inline] - fn apply>>(&self, μ: I) -> DVector { + fn apply>>(&self, μ: I) -> DVector { let mut y = self._zero_observable(); self.apply_add(&mut y, μ); y } } -impl Linear> for SensorGrid +impl Linear> for SensorGrid where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { } #[replace_float_literals(F::cast_from(literal))] -impl GEMV, DVector> for SensorGrid +impl GEMV, DVector> for SensorGrid where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { - fn gemv>>(&self, y: &mut DVector, α: F, μ: I, β: F) { + fn gemv>>(&self, y: &mut DVector, α: F, μ: I, β: F) { let grid = self.grid(); if β == 0.0 { y.fill(0.0) @@ -227,38 +226,42 @@ if α == 1.0 { self.apply_add(y, μ) } else { - for δ in μ.ref_instance() { - for &d in self.bt.iter_at(&δ.x) { - let sensor = self.shifted_sensor(grid.entry_linear_unchecked(d)); - y[d] += sensor.apply(&δ.x) * (α * δ.α); + μ.eval_ref(|μr| { + for δ in μr { + for &d in self.bt.iter_at(&δ.x) { + let sensor = self.shifted_sensor(grid.entry_linear_unchecked(d)); + y[d] += sensor.apply(&δ.x) * (α * δ.α); + } } - } + }) } } - fn apply_add>>(&self, y: &mut DVector, μ: I) { + fn apply_add>>(&self, y: &mut DVector, μ: I) { let grid = self.grid(); - for δ in μ.ref_instance() { - for &d in self.bt.iter_at(&δ.x) { - let sensor = 
self.shifted_sensor(grid.entry_linear_unchecked(d)); - y[d] += sensor.apply(&δ.x) * δ.α; + μ.eval_ref(|μr| { + for δ in μr { + for &d in self.bt.iter_at(&δ.x) { + let sensor = self.shifted_sensor(grid.entry_linear_unchecked(d)); + y[d] += sensor.apply(&δ.x) * δ.α; + } } - } + }) } } -impl BoundedLinear, Radon, L2, F> +impl BoundedLinear, Radon, L2, F> for SensorGrid where F: Float, - BT: SensorGridBT>, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, + BT: SensorGridBT, + S: Sensor, + P: Spread, + Convolution: Spread, { /// An estimate on the operator norm in $𝕃(ℳ(Ω); ℝ^n)$ with $ℳ(Ω)$ equipped /// with the Radon norm, and $ℝ^n$ with the Euclidean norm. - fn opnorm_bound(&self, _: Radon, _: L2) -> F { + fn opnorm_bound(&self, _: Radon, _: L2) -> DynResult { // With {x_i}_{i=1}^n the grid centres and φ the kernel, we have // |Aμ|_2 = sup_{|z|_2 ≤ 1} ⟨z,Αμ⟩ = sup_{|z|_2 ≤ 1} ⟨A^*z|μ⟩ // ≤ sup_{|z|_2 ≤ 1} |A^*z|_∞ |μ|_ℳ @@ -271,22 +274,22 @@ // = |φ|_∞ √N_ψ |μ|_ℳ. // Hence let n = self.max_overlapping(); - self.base_sensor.bounds().uniform() * n.sqrt() + Ok(self.base_sensor.bounds().uniform() * n.sqrt()) } } -type SensorGridPreadjoint<'a, A, F, const N: usize> = PreadjointHelper<'a, A, RNDM>; +type SensorGridPreadjoint<'a, A, F, const N: usize> = PreadjointHelper<'a, A, RNDM>; -impl Preadjointable, DVector> +impl Preadjointable, DVector> for SensorGrid where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, { - type PreadjointCodomain = BTFN, BT, N>; + type PreadjointCodomain = BTFN, BT, N>; type Preadjoint<'a> = SensorGridPreadjoint<'a, Self, F, N> where @@ -303,10 +306,10 @@ for SensorGridPreadjoint<'a, SensorGrid, F, N> where F : Float, BT : SensorGridBT, - S : Sensor, - P : Spread, - Convolution : Spread + Lipschitz + DifferentiableMapping> + LocalAnalysis, - for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, + S : Sensor, + P 
: Spread, + Convolution : Spread + Lipschitz + DifferentiableMapping> + LocalAnalysis, + for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, { type FloatType = F; @@ -332,55 +335,49 @@ */ #[replace_float_literals(F::cast_from(literal))] -impl<'a, F, S, P, BT, const N: usize> BoundedCurvature for SensorGrid +impl<'a, F, S, P, BT, const N: usize> BasicCurvatureBoundEstimates for SensorGrid where - F: Float, + F: Float + ToNalgebraRealField, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + S: Sensor, + P: Spread, + DVector: Euclidean, + Convolution: Spread + Lipschitz - + DifferentiableMapping> + + DifferentiableMapping> + LocalAnalysis, - for<'b> as DifferentiableMapping>>::Differential<'b>: + for<'b> as DifferentiableMapping>>::Differential<'b>: Lipschitz, { - type FloatType = F; - - /// Returns factors $ℓ_F$ and $Θ²$ such that - /// $B_{F'(μ)} dγ ≤ ℓ_F c_2$ and $⟨F'(μ)+F'(μ+Δ)|Δ⟩ ≤ Θ²|γ|(c_2)‖γ‖$, - /// where $Δ=(π_♯^1-π_♯^0)γ$. - /// - /// See Lemma 3.8, Lemma 5.10, Remark 5.14, and Example 5.15. 
- fn curvature_bound_components(&self) -> (Option, Option) { + fn basic_curvature_bound_components(&self) -> (DynResult, DynResult) { let n_ψ = self.max_overlapping(); let ψ_diff_lip = self.base_sensor.diff_ref().lipschitz_factor(L2); let ψ_lip = self.base_sensor.lipschitz_factor(L2); - let ℓ_F = ψ_diff_lip.map(|l| (2.0 * n_ψ).sqrt() * l); - let θ2 = ψ_lip.map(|l| 4.0 * n_ψ * l.powi(2)); + let ℓ_F0 = ψ_diff_lip.map(|l| (2.0 * n_ψ).sqrt() * l); + let Θ2 = ψ_lip.map(|l| 4.0 * n_ψ * l.powi(2)); - (ℓ_F, θ2) + (ℓ_F0, Θ2) } } #[derive(Clone, Debug)] -pub struct SensorGridSupportGenerator +pub struct SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, + S: Sensor, + P: Spread, { base_sensor: Convolution, - grid: LinGrid, + grid: LinGrid, weights: DVector, } -impl SensorGridSupportGenerator +impl SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { #[inline] fn construct_sensor(&self, id: usize, w: F) -> Weighted, F> { @@ -397,12 +394,12 @@ } } -impl SupportGenerator for SensorGridSupportGenerator +impl SupportGenerator for SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { type Id = usize; type SupportType = Weighted, F>; @@ -434,14 +431,14 @@ } } -impl ForwardModel, F>, F> +impl ForwardModel, F>, F> for SensorGrid where F: Float + ToNalgebraRealField + nalgebra::RealField, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, { type Observable = DVector; @@ -456,17 +453,17 @@ } } -impl FindimQuadraticModel, F> for SensorGrid +impl FindimQuadraticModel, F> for SensorGrid where F: Float + ToNalgebraRealField + nalgebra::RealField, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, { fn 
findim_quadratic_model( &self, - μ: &DiscreteMeasure, F>, + μ: &DiscreteMeasure, F>, b: &Self::Observable, ) -> (DMatrix, DVector) { assert_eq!(b.len(), self.n_sensors()); @@ -483,29 +480,29 @@ } } -/// Implements the calculation a factor $L$ such that $A_*A ≤ L 𝒟$ for $A$ the forward model +/// Implements the calculation a factor $√L$ such that $A_*A ≤ L 𝒟$ for $A$ the forward model /// and $𝒟$ a seminorm of suitable form. /// /// **This assumes (but does not check) that the sensors are not overlapping.** #[replace_float_literals(F::cast_from(literal))] -impl AdjointProductBoundedBy, ConvolutionOp> - for SensorGrid +impl<'a, F, BT, S, P, K, const N: usize> + BoundedLinear, &'a ConvolutionOp, L2, F> for SensorGrid where - F: Float + nalgebra::RealField + ToNalgebraRealField, + F: Float + ToNalgebraRealField, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread, - K: SimpleConvolutionKernel, + S: Sensor, + P: Spread, + Convolution: Spread, + K: SimpleConvolutionKernel, AutoConvolution

: BoundedBy, + Weighted, F>: LocalAnalysis, { - type FloatType = F; - - fn adjoint_product_bound(&self, seminorm: &ConvolutionOp) -> Option { + fn opnorm_bound(&self, seminorm: &'a ConvolutionOp, _: L2) -> DynResult { // Sensors should not take on negative values to allow // A_*A to be upper bounded by a simple convolution of `spread`. + // TODO: Do we really need this restriction? if self.sensor.bounds().lower() < 0.0 { - return None; + return Err(anyhow!("Sensor not bounded from below by zero")); } // Calculate the factor $L_1$ for betwee $ℱ[ψ * ψ] ≤ L_1 ℱ[ρ]$ for $ψ$ the base spread @@ -517,33 +514,33 @@ let l0 = self.sensor.norm(Linfinity) * self.sensor.norm(L1); // The final transition factor is: - Some(l0 * l1) + Ok((l0 * l1).sqrt()) } } macro_rules! make_sensorgridsupportgenerator_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { impl std::ops::$trait_assign - for SensorGridSupportGenerator + for SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { fn $fn_assign(&mut self, t: F) { self.weights.$fn_assign(t); } } - impl std::ops::$trait for SensorGridSupportGenerator + impl std::ops::$trait for SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { - type Output = SensorGridSupportGenerator; + type Output = SensorGridSupportGenerator; fn $fn(mut self, t: F) -> Self::Output { std::ops::$trait_assign::$fn_assign(&mut self.weights, t); self @@ -551,14 +548,14 @@ } impl<'a, F, S, P, const N: usize> std::ops::$trait - for &'a SensorGridSupportGenerator + for &'a SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { - type Output = SensorGridSupportGenerator; + type Output = SensorGridSupportGenerator; fn $fn(self, t: F) -> Self::Output { SensorGridSupportGenerator { 
base_sensor: self.base_sensor.clone(), @@ -575,14 +572,14 @@ macro_rules! make_sensorgridsupportgenerator_unaryop { ($trait:ident, $fn:ident) => { - impl std::ops::$trait for SensorGridSupportGenerator + impl std::ops::$trait for SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { - type Output = SensorGridSupportGenerator; + type Output = SensorGridSupportGenerator; fn $fn(mut self) -> Self::Output { self.weights = self.weights.$fn(); self @@ -590,14 +587,14 @@ } impl<'a, F, S, P, const N: usize> std::ops::$trait - for &'a SensorGridSupportGenerator + for &'a SensorGridSupportGenerator where F: Float, - S: Sensor, - P: Spread, - Convolution: Spread, + S: Sensor, + P: Spread, + Convolution: Spread, { - type Output = SensorGridSupportGenerator; + type Output = SensorGridSupportGenerator; fn $fn(self) -> Self::Output { SensorGridSupportGenerator { base_sensor: self.base_sensor.clone(), @@ -612,13 +609,13 @@ make_sensorgridsupportgenerator_unaryop!(Neg, neg); impl<'a, F, S, P, BT, const N: usize> Mapping> - for PreadjointHelper<'a, SensorGrid, RNDM> + for PreadjointHelper<'a, SensorGrid, RNDM> where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, N>, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, N>, { type Codomain = SensorGridBTFN; @@ -634,12 +631,12 @@ } impl<'a, F, S, P, BT, const N: usize> Linear> - for PreadjointHelper<'a, SensorGrid, RNDM> + for PreadjointHelper<'a, SensorGrid, RNDM> where F: Float, BT: SensorGridBT, - S: Sensor, - P: Spread, - Convolution: Spread + LocalAnalysis, N>, + S: Sensor, + P: Spread, + Convolution: Spread + LocalAnalysis, N>, { } diff -r 9738b51d90d7 -r 4f468d35fa29 src/forward_pdps.rs --- a/src/forward_pdps.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/forward_pdps.rs Thu Feb 26 11:38:43 2026 -0500 @@ -3,132 +3,158 @@ primal-dual proximal splitting with a forward step. 
*/ -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; - +use crate::fb::*; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; +use crate::regularisation::RegTerm; +use crate::types::*; +use alg_tools::convex::{Conjugable, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::error::DynResult; +use alg_tools::euclidean::ClosedEuclidean; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; -use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; -use alg_tools::norms::Norm; -use alg_tools::direct_product::Pair; +use alg_tools::linops::{BoundedLinear, IdOp, SimplyAdjointable, ZeroOp, AXPY, GEMV}; +use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::linops::{ - BoundedLinear, AXPY, GEMV, Adjointable, IdOp, -}; -use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, PairNorm}; - -use crate::types::*; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductPairBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::fb::*; -use crate::regularisation::RegTerm; -use crate::dataterm::calculate_residual; +use alg_tools::norms::L2; +use anyhow::ensure; +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; /// Settings for [`pointsource_forward_pdps_pair`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct ForwardPDPSConfig { - /// Primal step length scaling. - pub τ0 : F, - /// Primal step length scaling. - pub σp0 : F, - /// Dual step length scaling. 
- pub σd0 : F, +pub struct ForwardPDPSConfig { + /// Overall primal step length scaling. + pub τ0: F, + /// Primal step length scaling for additional variable. + pub σp0: F, + /// Dual step length scaling for additional variable. + /// + /// Taken zero for [`pointsource_fb_pair`]. + pub σd0: F, /// Generic parameters - pub insertion : FBGenericConfig, + pub insertion: InsertionConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for ForwardPDPSConfig { +impl Default for ForwardPDPSConfig { fn default() -> Self { - ForwardPDPSConfig { - τ0 : 0.99, - σd0 : 0.05, - σp0 : 0.99, - insertion : Default::default() - } + ForwardPDPSConfig { τ0: 0.99, σd0: 0.05, σp0: 0.99, insertion: Default::default() } } } -type MeasureZ = Pair, Z>; +type MeasureZ = Pair, Z>; /// Iteratively solve the pointsource localisation with an additional variable /// using primal-dual proximal splitting with a forward step. +/// +/// The problem is +/// $$ +/// \min_{μ, z}~ F(μ, z) + R(z) + H(K_z z) + Q(μ), +/// $$ +/// where +/// * The data term $F$ is given in `f`, +/// * the measure (Radon or positivity-constrained Radon) regulariser in $Q$ is given in `reg`, +/// * the functions $R$ and $H$ are given in `fnR` and `fnH`, and +/// * the operator $K_z$ in `opKz`. +/// +/// This is dualised to +/// $$ +/// \min_{μ, z}\max_y~ F(μ, z) + R(z) + ⟨K_z z, y⟩ + Q(μ) - H^*(y). +/// $$ +/// +/// The algorithm is controlled by: +/// * the proximal penalty in `prox_penalty`. +/// * the initial iterates in `z`, `y` +/// * The configuration in `config`. +/// * The `iterator` that controls stopping and reporting. +/// Moreover, plotting is performed by `plotter`. +/// +/// The step lengths need to satisfy +/// $$ +/// τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 +/// $$ ^^^^^^^^^^^^^^^^^^^^^^^^^ +/// with $1 > σ_p L_z$ and $1 > τ L$. 
+/// Since we are given “scalings” $τ_0$, $σ_{p,0}$, and $σ_{d,0}$ in `config`, we take +/// $σ_d=σ_{d,0}/‖K_z‖$, and $σ_p = σ_{p,0} / (L_z σ_d‖K_z‖)$. This satisfies the +/// part $[σ_p L_z + σ_pσ_d‖K_z‖^2] < 1$. Then with these cohices, we solve +/// $$ +/// τ = τ_0 \frac{1 - σ_{p,0}}{(σ_d M (1-σ_p L_z) + (1 - σ_{p,0} L)}. +/// $$ #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_forward_pdps_pair< - F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, + I, + S, + Dat, + Reg, + P, + Z, + R, + Y, + /*KOpM, */ KOpZ, + H, + Plot, + const N: usize, >( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - config : &ForwardPDPSConfig, - iterator : I, - mut plotter : SeqPlotter, + f: &Dat, + reg: &Reg, + prox_penalty: &P, + config: &ForwardPDPSConfig, + iterator: I, + mut plotter: Plot, + (μ0, mut z, mut y): (Option>, Z, Y), //opKμ : KOpM, - opKz : &KOpZ, - fnR : &R, - fnH : &H, - mut z : Z, - mut y : Y, -) -> MeasureZ + opKz: &KOpZ, + fnR: &R, + fnH: &H, +) -> DynResult> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel< - MeasureZ, - F, - PairNorm, - PreadjointCodomain = Pair, - > - + AdjointProductPairBoundedBy, P, IdOp, FloatType=F>, - S: DifferentiableRealMapping, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm, - P : ProxPenalty, - KOpZ : BoundedLinear - + GEMV - + Adjointable, - for<'b> KOpZ::Adjoint<'b> : GEMV, - Y : AXPY + Euclidean + Clone + ClosedAdd, - for<'b> &'b Y : Instance, - Z : AXPY + Euclidean + Clone + Norm, - for<'b> &'b Z : Instance, - R : Prox, - H : Conjugable, - for<'b> H::Conjugate<'b> : Prox, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair>, + //Pair: ClosedMul, // Doesn't really need to be closed, if make this signature more complex… + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: 
RegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + KOpZ: BoundedLinear + + GEMV + + SimplyAdjointable, + KOpZ::SimpleAdjoint: GEMV, + Y: ClosedEuclidean, + for<'b> &'b Y: Instance, + Z: ClosedEuclidean, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, + Plot: Plotter>, { - // Check parameters - assert!(config.τ0 > 0.0 && - config.τ0 < 1.0 && - config.σp0 > 0.0 && - config.σp0 < 1.0 && - config.σd0 > 0.0 && - config.σp0 * config.σd0 <= 1.0, - "Invalid step length parameters"); + // ensure!( + // config.τ0 > 0.0 + // && config.τ0 < 1.0 + // && config.σp0 > 0.0 + // && config.σp0 < 1.0 + // && config.σd0 >= 0.0 + // && config.σp0 * config.σd0 <= 1.0, + // "Invalid step length parameters" + // ); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = calculate_residual(Pair(&μ, &z), opA, b); + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); // Set up parameters let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); - let nKz = opKz.opnorm_bound(L2, L2); - let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); + let nKz = opKz.opnorm_bound(L2, L2)?; + let idOpZ = IdOp::new(); + let opKz_adj = opKz.adjoint(); + let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -137,14 +163,15 @@ // // To do so, we first solve σ_p and σ_d from standard PDPS step length condition // ^^^^^ < 1. then we solve τ from the rest. - let σ_d = config.σd0 / nKz; + // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. 
+ let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz }; let σ_p = config.σp0 / (l_z + config.σd0 * nKz); // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) let φ = 1.0 - config.σp0; let a = 1.0 - σ_p * l_z; - let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l ); + let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); // Acceleration is not currently supported // let γ = dataterm.factor_of_strong_convexity(); let ω = 1.0; @@ -157,28 +184,37 @@ let starH = fnH.conjugate(); // Statistics - let full_stats = |residual : &A::Observable, μ : &RNDM, z : &Z, ε, stats| IterInfo { - value : residual.norm2_squared_div2() + fnR.apply(z) - + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), - n_spikes : μ.len(), + let full_stats = |μ: &RNDM, z: &Z, ε, stats| IterInfo { + value: f.apply(Pair(μ, z)) + + fnR.apply(z) + + reg.apply(μ) + + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), + n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); + let Pair(mut τv, τz) = f.differential(Pair(&μ, &z)); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, &config.insertion, - ®, &state, &mut stats, - ); + &mut μ, + &mut τv, + &μ_base, + None, + τ, + ε, + &config.insertion, + ®, + &state, + &mut stats, + )?; // Merge spikes. 
// This crucially expects the merge routine to be stable with respect to spike locations, @@ -189,8 +225,9 @@ let ins = &config.insertion; if ins.merge_now(&state) { stats.merged += prox_penalty.merge_spikes_no_fitness( - &mut μ, &mut τv, &μ_base, None, τ, ε, ins, ®, - //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), + &mut μ, &mut τv, &μ_base, None, τ, ε, ins, + ®, + //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), ); } @@ -199,19 +236,16 @@ // Do z variable primal update let mut z_new = τz; - opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); + opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); z_new = fnR.prox(σ_p, z_new + &z); // Do dual update // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] - opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); + opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b - opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b y = starH.prox(σ_d, y); z = z_new; - // Update residual - residual = calculate_residual(Pair(&μ, &z), opA, b); - // Update step length parameters // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); @@ -221,20 +255,73 @@ state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - let fit = |μ̃ : &RNDM| { - (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() - //+ fnR.apply(z) + reg.apply(μ) + let fit = |μ̃: &RNDM| { + f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) }; 
μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); μ.prune(); - Pair(μ, z) + Ok(Pair(μ, z)) } + +/// Iteratively solve the pointsource localisation with an additional variable +/// using forward-backward splitting. +/// +/// The implementation uses [`pointsource_forward_pdps_pair`] with appropriate dummy +/// variables, operators, and functions. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_fb_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + config: &FBConfig, + iterator: I, + plotter: Plot, + (μ0, z): (Option>, Z), + //opKμ : KOpM, + fnR: &R, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair>, + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: RegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + Z: ClosedEuclidean + AXPY + Clone, + for<'b> &'b Z: Instance, + R: Prox, + Plot: Plotter>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance>, +{ + let opKz = ZeroOp::new_dualisable(Loc([]), z.dual_origin()); + let fnH = Zero::new(); + // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 + // needs to be chosen in a general case; for the problem of this fucntion, anything is valid. 
+ let &FBConfig { τ0, σp0, insertion } = config; + let pdps_config = ForwardPDPSConfig { τ0, σp0, insertion, σd0: 0.0 }; + + pointsource_forward_pdps_pair( + f, + reg, + prox_penalty, + &pdps_config, + iterator, + plotter, + (μ0, z, Loc([])), + &opKz, + fnR, + &fnH, + ) +} diff -r 9738b51d90d7 -r 4f468d35fa29 src/fourier.rs --- a/src/fourier.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/fourier.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,31 +2,32 @@ Fourier transform traits */ -use alg_tools::types::{Num, Float}; -use alg_tools::mapping::{RealMapping, Mapping, Space}; use alg_tools::bisection_tree::Weighted; use alg_tools::loc::Loc; +use alg_tools::mapping::{Mapping, RealMapping, Space}; +use alg_tools::types::{Float, Num}; /// Trait for Fourier transforms. When F is a non-complex number, the transform /// also has to be non-complex, i.e., the function itself symmetric. -pub trait Fourier : Mapping { - type Domain : Space; - type Transformed : Mapping; +pub trait Fourier: Mapping { + type Domain: Space; + type Transformed: Mapping; fn fourier(&self) -> Self::Transformed; } -impl Fourier -for Weighted -where T : Fourier> + RealMapping { +impl Fourier for Weighted +where + T: Fourier> + RealMapping, +{ type Domain = T::Domain; type Transformed = Weighted; #[inline] fn fourier(&self) -> Self::Transformed { Weighted { - base_fn : self.base_fn.fourier(), - weight : self.weight + base_fn: self.base_fn.fourier(), + weight: self.weight, } } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/frank_wolfe.rs --- a/src/frank_wolfe.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/frank_wolfe.rs Thu Feb 26 11:38:43 2026 -0500 @@ -13,82 +13,51 @@ DOI: [10.1051/cocv/2011205](https://doi.org/0.1051/cocv/2011205). 
*/ +use nalgebra::{DMatrix, DVector}; use numeric_literals::replace_float_literals; -use nalgebra::{DMatrix, DVector}; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; //use colored::Colorize; - -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorOptions, - ValueIteratorFactory, -}; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::Norm; -use alg_tools::linops::Mapping; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::L2; - -use crate::types::*; -use crate::measures::{ - RNDM, - DiscreteMeasure, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::dataterm::QuadraticDataTerm; use crate::forward_model::ForwardModel; +use crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; +use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM}; +use crate::plot::Plotter; +use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, RegTerm}; #[allow(unused_imports)] // Used in documentation use crate::subproblem::{ - unconstrained::quadratic_unconstrained, - nonneg::quadratic_nonneg, - InnerSettings, - InnerMethod, + nonneg::quadratic_nonneg, unconstrained::quadratic_unconstrained, InnerMethod, InnerSettings, }; use crate::tolerance::Tolerance; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::regularisation::{ - NonnegRadonRegTerm, - RadonRegTerm, - RegTerm -}; +use crate::types::*; +use alg_tools::bisection_tree::P2Minimise; +use alg_tools::bounds::MinMaxMapping; +use alg_tools::error::DynResult; +use alg_tools::euclidean::Euclidean; +use alg_tools::instance::Instance; +use alg_tools::iterate::{AlgIteratorFactory, AlgIteratorOptions, ValueIteratorFactory}; +use 
alg_tools::linops::Mapping; +use alg_tools::loc::Loc; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::Norm; +use alg_tools::norms::L2; +use alg_tools::sets::Cube; /// Settings for [`pointsource_fw_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FWConfig { +pub struct FWConfig { /// Tolerance for branch-and-bound new spike location discovery - pub tolerance : Tolerance, + pub tolerance: Tolerance, /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`] /// as the conditional gradient subproblems' optimality conditions do not in general have an /// invertible Newton derivative for SSN. - pub inner : InnerSettings, + pub inner: InnerSettings, /// Variant of the conditional gradient method - pub variant : FWVariant, + pub variant: FWVariant, /// Settings for branch and bound refinement when looking for predual maxima - pub refinement : RefinementSettings, + pub refinement: RefinementSettings, /// Spike merging heuristic - pub merging : SpikeMergingMethod, + pub merging: SpikeMergingMethod, } /// Conditional gradient method variant; see also [`FWConfig`]. 
@@ -101,51 +70,51 @@ Relaxed, } -impl Default for FWConfig { +impl Default for FWConfig { fn default() -> Self { FWConfig { - tolerance : Default::default(), - refinement : Default::default(), - inner : Default::default(), - variant : FWVariant::FullyCorrective, - merging : SpikeMergingMethod { enabled : true, ..Default::default() }, + tolerance: Default::default(), + refinement: Default::default(), + inner: Default::default(), + variant: FWVariant::FullyCorrective, + merging: SpikeMergingMethod { enabled: true, ..Default::default() }, } } } -pub trait FindimQuadraticModel : ForwardModel, F> +pub trait FindimQuadraticModel: ForwardModel, F> where - F : Float + ToNalgebraRealField, - Domain : Clone + PartialEq, + F: Float + ToNalgebraRealField, + Domain: Clone + PartialEq, { /// Return A_*A and A_* b fn findim_quadratic_model( &self, - μ : &DiscreteMeasure, - b : &Self::Observable + μ: &DiscreteMeasure, + b: &Self::Observable, ) -> (DMatrix, DVector); } /// Helper struct for pre-initialising the finite-dimensional subproblem solver. -pub struct FindimData { +pub struct FindimData { /// ‖A‖^2 - opAnorm_squared : F, + opAnorm_squared: F, /// Bound $M_0$ from the Bredies–Pikkarainen article. - m0 : F + m0: F, } /// Trait for finite dimensional weight optimisation. pub trait WeightOptim< - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - I : AlgIteratorFactory, - const N : usize -> { - + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + I: AlgIteratorFactory, + const N: usize, +> +{ /// Return a pre-initialisation struct for [`Self::optimise_weights`]. /// /// The parameter `opA` is the forward operator $A$. - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData; + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult>; /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity /// point source localisation problem. 
@@ -166,72 +135,70 @@ /// Returns the number of iterations taken by the method configured in `inner`. fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize; } /// Trait for regularisation terms supported by [`pointsource_fw_reg`]. pub trait RegTermFW< - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - I : AlgIteratorFactory, - const N : usize -> : RegTerm - + WeightOptim - + Mapping, Codomain = F> { - + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + I: AlgIteratorFactory, + const N: usize, +>: RegTerm, F> + WeightOptim + Mapping, Codomain = F> +{ /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted /// into $μ$, as determined by the regulariser. /// /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant - /// [`BTFN`] minimisation and maximisation routines. + /// [`MinMaxMapping`] minimisation and maximisation routines. fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F); + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F); /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen. 
fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ); } #[replace_float_literals(F::cast_from(literal))] -impl WeightOptim -for RadonRegTerm -where I : AlgIteratorFactory, - A : FindimQuadraticModel, F> { - - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData { - FindimData{ - opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), - m0 : b.norm2_squared() / (2.0 * self.α()), - } +impl WeightOptim + for RadonRegTerm +where + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, +{ + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult> { + Ok(FindimData { + opAnorm_squared: opA.opnorm_bound(Radon, L2)?.powi(2), + m0: b.norm2_squared() / (2.0 * self.α()), + }) } fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize { - // Form and solve finite-dimensional subproblem. let (Ã, g̃) = opA.findim_quadratic_model(&μ, b); let mut x = μ.masses_dvector(); @@ -245,8 +212,7 @@ // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no // square root is needed when we scale: let normest = findim_data.opAnorm_squared * F::cast_from(μ.len()); - let iters = quadratic_unconstrained(&Ã, &g̃, self.α(), &mut x, - normest, inner, iterator); + let iters = quadratic_unconstrained(&Ã, &g̃, self.α(), &mut x, normest, inner, iterator); // Update masses of μ based on solution of finite-dimensional subproblem. 
μ.set_masses_dvector(&x); @@ -255,28 +221,23 @@ } #[replace_float_literals(F::cast_from(literal))] -impl RegTermFW -for RadonRegTerm +impl RegTermFW for RadonRegTerm where - Cube : P2Minimise, F>, - I : AlgIteratorFactory, - S: RealMapping + LocalAnalysis, N>, - GA : SupportGenerator + Clone, - A : FindimQuadraticModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, + Cube: P2Minimise, F>, + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'a> &'a A::PreadjointCodomain: Instance, // FIXME: the following *should not* be needed, they are already implied - RNDM : Mapping, - DeltaMeasure, F> : Mapping, - //A : Mapping, Codomain = A::Observable>, - //A : Mapping, F>, Codomain = A::Observable>, + RNDM: Mapping, + DeltaMeasure, F>: Mapping, { - fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F) { + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F) { let (ξmax, v_ξmax) = g.maximise(refinement_tolerance, max_steps); let (ξmin, v_ξmin) = g.minimise(refinement_tolerance, max_steps); if v_ξmin < 0.0 && -v_ξmin > v_ξmax { @@ -288,25 +249,35 @@ fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ) { let α = self.0; let m0 = findim_data.m0; - let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; - let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; - let δ = DeltaMeasure { x : ξ, α : v }; + let φ = |t| { + if t <= m0 { + α * t + } else { + α / (2.0 * m0) * (t * t + m0 * m0) + } + }; + let v = if v_ξ.abs() <= α { + 0.0 + } else { + m0 / α * v_ξ + }; + let δ = DeltaMeasure { x: ξ, α: v }; let dp = μ.apply(g) - δ.apply(g); let d = opA.apply(&*μ) - opA.apply(δ); let r = d.norm2_squared(); let 
s = if r == 0.0 { 1.0 } else { - 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) + 1.0.min((α * μ.norm(Radon) - φ(v.abs()) - dp) / r) }; *μ *= 1.0 - s; *μ += δ * s; @@ -314,28 +285,28 @@ } #[replace_float_literals(F::cast_from(literal))] -impl WeightOptim -for NonnegRadonRegTerm -where I : AlgIteratorFactory, - A : FindimQuadraticModel, F> { - - fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData { - FindimData{ - opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), - m0 : b.norm2_squared() / (2.0 * self.α()), - } +impl WeightOptim + for NonnegRadonRegTerm +where + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, +{ + fn prepare_optimise_weights(&self, opA: &A, b: &A::Observable) -> DynResult> { + Ok(FindimData { + opAnorm_squared: opA.opnorm_bound(Radon, L2)?.powi(2), + m0: b.norm2_squared() / (2.0 * self.α()), + }) } fn optimise_weights<'a>( &self, - μ : &mut RNDM, - opA : &'a A, - b : &A::Observable, - findim_data : &FindimData, - inner : &InnerSettings, - iterator : I + μ: &mut RNDM, + opA: &'a A, + b: &A::Observable, + findim_data: &FindimData, + inner: &InnerSettings, + iterator: I, ) -> usize { - // Form and solve finite-dimensional subproblem. let (Ã, g̃) = opA.findim_quadratic_model(&μ, b); let mut x = μ.masses_dvector(); @@ -349,8 +320,7 @@ // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no // square root is needed when we scale: let normest = findim_data.opAnorm_squared * F::cast_from(μ.len()); - let iters = quadratic_nonneg(&Ã, &g̃, self.α(), &mut x, - normest, inner, iterator); + let iters = quadratic_nonneg(&Ã, &g̃, self.α(), &mut x, normest, inner, iterator); // Update masses of μ based on solution of finite-dimensional subproblem. 
μ.set_masses_dvector(&x); @@ -359,59 +329,65 @@ } #[replace_float_literals(F::cast_from(literal))] -impl RegTermFW -for NonnegRadonRegTerm +impl RegTermFW + for NonnegRadonRegTerm where - Cube : P2Minimise, F>, - I : AlgIteratorFactory, - S: RealMapping + LocalAnalysis, N>, - GA : SupportGenerator + Clone, - A : FindimQuadraticModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, + Cube: P2Minimise, F>, + I: AlgIteratorFactory, + A: FindimQuadraticModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'a> &'a A::PreadjointCodomain: Instance, // FIXME: the following *should not* be needed, they are already implied - RNDM : Mapping, - DeltaMeasure, F> : Mapping, + RNDM: Mapping, + DeltaMeasure, F>: Mapping, { - fn find_insertion( &self, - g : &mut A::PreadjointCodomain, - refinement_tolerance : F, - max_steps : usize - ) -> (Loc, F) { + g: &mut A::PreadjointCodomain, + refinement_tolerance: F, + max_steps: usize, + ) -> (Loc, F) { g.maximise(refinement_tolerance, max_steps) } - fn relaxed_insert<'a>( &self, - μ : &mut RNDM, - g : &A::PreadjointCodomain, - opA : &'a A, - ξ : Loc, - v_ξ : F, - findim_data : &FindimData + μ: &mut RNDM, + g: &A::PreadjointCodomain, + opA: &'a A, + ξ: Loc, + v_ξ: F, + findim_data: &FindimData, ) { // This is just a verbatim copy of RadonRegTerm::relaxed_insert. 
let α = self.0; let m0 = findim_data.m0; - let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; - let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; - let δ = DeltaMeasure { x : ξ, α : v }; + let φ = |t| { + if t <= m0 { + α * t + } else { + α / (2.0 * m0) * (t * t + m0 * m0) + } + }; + let v = if v_ξ.abs() <= α { + 0.0 + } else { + m0 / α * v_ξ + }; + let δ = DeltaMeasure { x: ξ, α: v }; let dp = μ.apply(g) - δ.apply(g); let d = opA.apply(&*μ) - opA.apply(&δ); let r = d.norm2_squared(); let s = if r == 0.0 { 1.0 } else { - 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) + 1.0.min((α * μ.norm(Radon) - φ(v.abs()) - dp) / r) }; *μ *= 1.0 - s; *μ += δ * s; } } - /// Solve point source localisation problem using a conditional gradient method /// for the 2-norm-squared data fidelity, i.e., the problem ///

$$ @@ -425,49 +401,48 @@ /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to /// save intermediate iteration states as images. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fw_reg( - opA : &A, - b : &A::Observable, - reg : Reg, - //domain : Cube, - config : &FWConfig, - iterator : I, - mut plotter : SeqPlotter, -) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTermFW, N> { +pub fn pointsource_fw_reg<'a, F, I, A, Reg, Plot, const N: usize>( + f: &'a QuadraticDataTerm, A>, + reg: &Reg, + //domain : Cube, + config: &FWConfig, + iterator: I, + mut plotter: Plot, + μ0 : Option>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + &'a A::PreadjointCodomain: Instance, + Cube: P2Minimise, F>, + RNDM: SpikeMerging, + Reg: RegTermFW, N>, + Plot: Plotter>, +{ + let opA = f.operator(); + let b = f.data(); // Set up parameters // We multiply tolerance by α for all algoritms. 
let tolerance = config.tolerance * reg.tolerance_scaling(); let mut ε = tolerance.initial(); - let findim_data = reg.prepare_optimise_weights(opA, b); + let findim_data = reg.prepare_optimise_weights(opA, b)?; // Initialise operators let preadjA = opA.preadjoint(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut residual = f.residual(&μ); // Statistics - let full_stats = |residual : &A::Observable, - ν : &RNDM, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(ν), - n_spikes : ν.len(), + let full_stats = |residual: &A::Observable, ν: &RNDM, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + reg.apply(ν), + n_spikes: ν.len(), ε, - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -480,32 +455,34 @@ let mut g = preadjA.apply(residual * (-1.0)); // Find absolute value maximising point - let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance, - config.refinement.max_steps); + let (ξ, v_ξ) = + reg.find_insertion(&mut g, refinement_tolerance, config.refinement.max_steps); let inner_it = match config.variant { FWVariant::FullyCorrective => { // No point in optimising the weight here: the finite-dimensional algorithm is fast. - μ += DeltaMeasure { x : ξ, α : 0.0 }; + μ += DeltaMeasure { x: ξ, α: 0.0 }; stats.inserted += 1; config.inner.iterator_options.stop_target(inner_tolerance) - }, + } FWVariant::Relaxed => { // Perform a relaxed initialisation of μ reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data); stats.inserted += 1; // The stop_target is only needed for the type system. - AlgIteratorOptions{ max_iter : 1, .. 
config.inner.iterator_options}.stop_target(0.0) + AlgIteratorOptions { max_iter: 1, ..config.inner.iterator_options }.stop_target(0.0) } }; - stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, - &config.inner, inner_it); - + stats.inner_iters += + reg.optimise_weights(&mut μ, opA, b, &findim_data, &config.inner, inner_it); + // Merge spikes and update residual for next step and `if_verbose` below. - let (r, count) = μ.merge_spikes_fitness(config.merging, - |μ̃| opA.apply(μ̃) - b, - A::Observable::norm2_squared); + let (r, count) = μ.merge_spikes_fitness( + config.merging, + |μ̃| f.residual(μ̃), + A::Observable::norm2_squared, + ); residual = r; stats.merged += count; @@ -520,8 +497,13 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ); - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + plotter.plot_spikes(iter, Some(&g), Option::<&A::PreadjointCodomain>::None, &μ); + full_stats( + &residual, + &μ, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); // Update tolerance @@ -529,5 +511,5 @@ } // Return final iterate - μ + Ok(μ) } diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/ball_indicator.rs --- a/src/kernels/ball_indicator.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/ball_indicator.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,56 +1,44 @@ - //! Implementation of the indicator function of a ball with respect to various norms. 
+use super::base::*; +use crate::types::*; +use alg_tools::bisection_tree::{Bounds, Constant, GlobalAnalysis, LocalAnalysis, Support}; +use alg_tools::coefficients::factorial; +use alg_tools::euclidean::StaticEuclidean; +use alg_tools::instance::Instance; +use alg_tools::loc::Loc; +use alg_tools::mapping::{DifferentiableImpl, Differential, LipschitzDifferentiableImpl, Mapping}; +use alg_tools::maputil::array_init; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use anyhow::anyhow; use float_extras::f64::tgamma as gamma; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Constant, - Bounds, - LocalAnalysis, - GlobalAnalysis, -}; -use alg_tools::mapping::{ - Mapping, - Differential, - DifferentiableImpl, -}; -use alg_tools::instance::Instance; -use alg_tools::euclidean::StaticEuclidean; -use alg_tools::maputil::array_init; -use alg_tools::coefficients::factorial; -use crate::types::*; -use super::base::*; /// Representation of the indicator of the ball $𝔹_q = \\{ x ∈ ℝ^N \mid \\|x\\|\_q ≤ r \\}$, /// where $q$ is the `Exponent`, and $r$ is the radius [`Constant`] `C`. -#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] -pub struct BallIndicator { +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] +pub struct BallIndicator { /// The radius of the ball. - pub r : C, + pub r: C, /// The exponent $q$ of the norm creating the ball - pub exponent : Exponent, + pub exponent: Exponent, } /// Alias for the representation of the indicator of the $∞$-norm-ball /// $𝔹_∞ = \\{ x ∈ ℝ^N \mid \\|x\\|\_∞ ≤ c \\}$. 
-pub type CubeIndicator = BallIndicator; +pub type CubeIndicator = BallIndicator; #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -Mapping> -for BallIndicator +impl<'a, F: Float, C: Constant, Exponent: NormExponent, const N: usize> + Mapping> for BallIndicator where - Loc : Norm + Loc: Norm, { type Codomain = C::Type; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { let r = self.r.value(); let n = x.eval(|x| x.norm(self.exponent)); if n <= r { @@ -61,114 +49,105 @@ } } -impl<'a, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -DifferentiableImpl> -for BallIndicator +impl<'a, F: Float, C: Constant, Exponent: NormExponent, const N: usize> + DifferentiableImpl> for BallIndicator where - C : Constant, - Loc : Norm + C: Constant, + Loc: Norm, { - type Derivative = Loc; + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, _x : I) -> Self::Derivative { + fn differential_impl>>(&self, _x: I) -> Self::Derivative { Self::Derivative::origin() } } -impl, Exponent : NormExponent, const N : usize> -Lipschitz -for BallIndicator -where C : Constant, - Loc : Norm { +impl, Exponent: NormExponent, const N: usize> Lipschitz + for BallIndicator +where + C: Constant, + Loc: Norm, +{ type FloatType = C::Type; - fn lipschitz_factor(&self, _l2 : L2) -> Option { - None + fn lipschitz_factor(&self, _l2: L2) -> DynResult { + Err(anyhow!("Not a Lipschitz function")) } } -impl<'b, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -Lipschitz -for Differential<'b, Loc, BallIndicator> -where C : Constant, - Loc : Norm { +impl<'b, F: Float, C: Constant, Exponent: NormExponent, const N: usize> + LipschitzDifferentiableImpl, L2> for BallIndicator +where + C: Constant, + Loc: Norm, +{ type FloatType = C::Type; - fn lipschitz_factor(&self, _l2 : L2) -> Option { - None + fn diff_lipschitz_factor(&self, _l2: L2) -> 
DynResult { + Err(anyhow!("Not a Lipschitz-differentiable function")) } } -impl<'a, 'b, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -Lipschitz -for Differential<'b, Loc, &'a BallIndicator> -where C : Constant, - Loc : Norm { +impl<'b, F: Float, C: Constant, Exponent: NormExponent, const N: usize> NormBounded + for Differential<'b, Loc, BallIndicator> +where + C: Constant, + Loc: Norm, +{ type FloatType = C::Type; - fn lipschitz_factor(&self, _l2 : L2) -> Option { - None - } -} - - -impl<'b, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -NormBounded -for Differential<'b, Loc, BallIndicator> -where C : Constant, - Loc : Norm { - type FloatType = C::Type; - - fn norm_bound(&self, _l2 : L2) -> C::Type { + fn norm_bound(&self, _l2: L2) -> C::Type { F::INFINITY } } -impl<'a, 'b, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -NormBounded -for Differential<'b, Loc, &'a BallIndicator> -where C : Constant, - Loc : Norm { +impl<'a, 'b, F: Float, C: Constant, Exponent: NormExponent, const N: usize> + NormBounded for Differential<'b, Loc, &'a BallIndicator> +where + C: Constant, + Loc: Norm, +{ type FloatType = C::Type; - fn norm_bound(&self, _l2 : L2) -> C::Type { + fn norm_bound(&self, _l2: L2) -> C::Type { F::INFINITY } } - -impl<'a, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -Support -for BallIndicator -where Loc : Norm, - Linfinity : Dominated> { - +impl<'a, F: Float, C: Constant, Exponent, const N: usize> Support + for BallIndicator +where + Exponent: NormExponent + Sync + Send + 'static, + Loc: Norm, + Linfinity: Dominated>, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let r = Linfinity.from_norm(self.r.value(), self.exponent); array_init(|| [-r, r]).into() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { let r = Linfinity.from_norm(self.r.value(), self.exponent); x.norm(self.exponent) <= r } /// 
This can only really work in a reasonable fashion for N=1. #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let r = Linfinity.from_norm(self.r.value(), self.exponent); cube.map(|a, b| symmetric_interval_hint(r, a, b)) } } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -GlobalAnalysis> -for BallIndicator -where Loc : Norm { +impl<'a, F: Float, C: Constant, Exponent: NormExponent, const N: usize> + GlobalAnalysis> for BallIndicator +where + Loc: Norm, +{ #[inline] fn global_analysis(&self) -> Bounds { Bounds(0.0, 1.0) @@ -176,29 +155,28 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, C : Constant, Exponent : NormExponent, const N : usize> -Norm -for BallIndicator -where Loc : Norm { +impl<'a, F: Float, C: Constant, Exponent: NormExponent, const N: usize> Norm + for BallIndicator +where + Loc: Norm, +{ #[inline] - fn norm(&self, _ : Linfinity) -> F { + fn norm(&self, _: Linfinity) -> F { 1.0 } } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, C : Constant, const N : usize> -Norm -for BallIndicator { +impl<'a, F: Float, C: Constant, const N: usize> Norm for BallIndicator { #[inline] - fn norm(&self, _ : L1) -> F { + fn norm(&self, _: L1) -> F { // Using https://en.wikipedia.org/wiki/Volume_of_an_n-ball#Balls_in_Lp_norms, // we have V_N^1(r) = (2r)^N / N! 
let r = self.r.value(); - if N==1 { + if N == 1 { 2.0 * r - } else if N==2 { - r*r + } else if N == 2 { + r * r } else { (2.0 * r).powi(N as i32) * F::cast_from(factorial(N)) } @@ -206,17 +184,15 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, C : Constant, const N : usize> -Norm -for BallIndicator { +impl<'a, F: Float, C: Constant, const N: usize> Norm for BallIndicator { #[inline] - fn norm(&self, _ : L1) -> F { + fn norm(&self, _: L1) -> F { // See https://en.wikipedia.org/wiki/Volume_of_an_n-ball#The_volume. let r = self.r.value(); let π = F::PI; - if N==1 { + if N == 1 { 2.0 * r - } else if N==2 { + } else if N == 2 { π * (r * r) } else { let ndiv2 = F::cast_from(N) / 2.0; @@ -227,94 +203,100 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, C : Constant, const N : usize> -Norm -for BallIndicator { +impl<'a, F: Float, C: Constant, const N: usize> Norm + for BallIndicator +{ #[inline] - fn norm(&self, _ : L1) -> F { + fn norm(&self, _: L1) -> F { let two_r = 2.0 * self.r.value(); two_r.powi(N as i32) } } - macro_rules! indicator_local_analysis { ($exponent:ident) => { - impl<'a, F : Float, C : Constant, const N : usize> - LocalAnalysis, N> - for BallIndicator - where Loc : Norm, - Linfinity : Dominated> { + impl<'a, F: Float, C: Constant, const N: usize> LocalAnalysis, N> + for BallIndicator + where + Loc: Norm<$exponent, F>, + Linfinity: Dominated>, + { #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the 2-norm is minimised/maximised. 
let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); Bounds(lower, upper) } } - } + }; } indicator_local_analysis!(L1); indicator_local_analysis!(L2); indicator_local_analysis!(Linfinity); - #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, const N : usize> Mapping> -for AutoConvolution> -where R : Constant { +impl<'a, F: Float, R, const N: usize> Mapping> for AutoConvolution> +where + R: Constant, +{ type Codomain = F; #[inline] - fn apply>>(&self, y : I) -> F { + fn apply>>(&self, y: I) -> F { let two_r = 2.0 * self.0.r.value(); // This is just a product of one-dimensional versions - y.cow().iter().map(|&x| { - 0.0.max(two_r - x.abs()) - }).product() + y.decompose() + .iter() + .map(|&x| 0.0.max(two_r - x.abs())) + .product() } } #[replace_float_literals(F::cast_from(literal))] -impl Support -for AutoConvolution> -where R : Constant { +impl Support for AutoConvolution> +where + R: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let two_r = 2.0 * self.0.r.value(); array_init(|| [-two_r, two_r]).into() } #[inline] - fn in_support(&self, y : &Loc) -> bool { + fn in_support(&self, y: &Loc) -> bool { let two_r = 2.0 * self.0.r.value(); y.iter().all(|x| x.abs() <= two_r) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let two_r = 2.0 * self.0.r.value(); cube.map(|c, d| symmetric_interval_hint(two_r, c, d)) } } #[replace_float_literals(F::cast_from(literal))] -impl GlobalAnalysis> -for AutoConvolution> -where R : Constant { +impl GlobalAnalysis> + for AutoConvolution> +where + R: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { Bounds(0.0, self.apply(Loc::ORIGIN)) } } -impl LocalAnalysis, N> -for AutoConvolution> -where R : Constant { +impl LocalAnalysis, N> + for AutoConvolution> +where + R: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn 
local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the absolute value is minimised/maximised. let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/base.rs --- a/src/kernels/base.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/base.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,28 +1,20 @@ - //! Things for constructing new kernels from component kernels and traits for analysing them -use serde::Serialize; use numeric_literals::replace_float_literals; +use serde::Serialize; -use alg_tools::types::*; +use alg_tools::bisection_tree::Support; +use alg_tools::bounds::{Bounded, Bounds, GlobalAnalysis, LocalAnalysis}; +use alg_tools::instance::{Instance, Space}; +use alg_tools::loc::Loc; +use alg_tools::mapping::{ + DifferentiableImpl, DifferentiableMapping, LipschitzDifferentiableImpl, Mapping, +}; +use alg_tools::maputil::{array_init, map1_indexed, map2}; use alg_tools::norms::*; -use alg_tools::loc::Loc; use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Bounds, - LocalAnalysis, - GlobalAnalysis, - Bounded, -}; -use alg_tools::mapping::{ - Mapping, - DifferentiableImpl, - DifferentiableMapping, - Differential, -}; -use alg_tools::instance::{Instance, Space}; -use alg_tools::maputil::{array_init, map2, map1_indexed}; use alg_tools::sets::SetOrd; +use alg_tools::types::*; +use anyhow::anyhow; use crate::fourier::Fourier; use crate::types::*; @@ -32,134 +24,129 @@ /// The kernels typically implement [`Support`] and [`Mapping`]. /// /// The implementation [`Support`] only uses the [`Support::support_hint`] of the first parameter! 
-#[derive(Copy,Clone,Serialize,Debug)] +#[derive(Copy, Clone, Serialize, Debug)] pub struct SupportProductFirst( /// First kernel pub A, /// Second kernel - pub B + pub B, ); -impl Mapping> -for SupportProductFirst +impl Mapping> for SupportProductFirst where - A : Mapping, Codomain = F>, - B : Mapping, Codomain = F>, + A: Mapping, Codomain = F>, + B: Mapping, Codomain = F>, { type Codomain = F; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { - self.0.apply(x.ref_instance()) * self.1.apply(x) + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_ref(|r| self.0.apply(r)) * self.1.apply(x) } } -impl DifferentiableImpl> -for SupportProductFirst +impl DifferentiableImpl> for SupportProductFirst where - A : DifferentiableMapping< - Loc, - DerivativeDomain=Loc, - Codomain = F - >, - B : DifferentiableMapping< - Loc, - DerivativeDomain=Loc, - Codomain = F, - > + A: DifferentiableMapping, DerivativeDomain = Loc, Codomain = F>, + B: DifferentiableMapping, DerivativeDomain = Loc, Codomain = F>, { - type Derivative = Loc; + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, x : I) -> Self::Derivative { - let xr = x.ref_instance(); - self.0.differential(xr) * self.1.apply(xr) + self.1.differential(xr) * self.0.apply(x) + fn differential_impl>>(&self, x: I) -> Self::Derivative { + x.eval_ref(|xr| { + self.0.differential(xr) * self.1.apply(xr) + self.1.differential(xr) * self.0.apply(xr) + }) } } -impl Lipschitz -for SupportProductFirst -where A : Lipschitz + Bounded, - B : Lipschitz + Bounded { +impl Lipschitz for SupportProductFirst +where + A: Lipschitz + Bounded, + B: Lipschitz + Bounded, +{ type FloatType = F; #[inline] - fn lipschitz_factor(&self, m : M) -> Option { + fn lipschitz_factor(&self, m: M) -> DynResult { // f(x)g(x) - f(y)g(y) = f(x)[g(x)-g(y)] - [f(y)-f(x)]g(y) let &SupportProductFirst(ref f, ref g) = self; - f.lipschitz_factor(m).map(|l| l * g.bounds().uniform()) - .zip(g.lipschitz_factor(m).map(|l| l * f.bounds().uniform())) - 
.map(|(a, b)| a + b) + Ok(f.lipschitz_factor(m)? * g.bounds().uniform() + + g.lipschitz_factor(m)? * f.bounds().uniform()) } } -impl<'a, A, B, M : Copy, Domain, F : Float> Lipschitz -for Differential<'a, Domain, SupportProductFirst> +impl<'a, A, B, M: Copy, Domain, F: Float> LipschitzDifferentiableImpl + for SupportProductFirst where - Domain : Space, - A : Clone + DifferentiableMapping + Lipschitz + Bounded, - B : Clone + DifferentiableMapping + Lipschitz + Bounded, - SupportProductFirst : DifferentiableMapping, - for<'b> A::Differential<'b> : Lipschitz + NormBounded, - for<'b> B::Differential<'b> : Lipschitz + NormBounded + Domain: Space, + Self: DifferentiableImpl, + A: DifferentiableMapping + Lipschitz + Bounded, + B: DifferentiableMapping + Lipschitz + Bounded, + SupportProductFirst: DifferentiableMapping, + for<'b> A::Differential<'b>: Lipschitz + NormBounded, + for<'b> B::Differential<'b>: Lipschitz + NormBounded, { type FloatType = F; #[inline] - fn lipschitz_factor(&self, m : M) -> Option { + fn diff_lipschitz_factor(&self, m: M) -> DynResult { // ∇[gf] = f∇g + g∇f // ⟹ ∇[gf](x) - ∇[gf](y) = f(x)∇g(x) + g(x)∇f(x) - f(y)∇g(y) + g(y)∇f(y) // = f(x)[∇g(x)-∇g(y)] + g(x)∇f(x) - [f(y)-f(x)]∇g(y) + g(y)∇f(y) // = f(x)[∇g(x)-∇g(y)] + g(x)[∇f(x)-∇f(y)] // - [f(y)-f(x)]∇g(y) + [g(y)-g(x)]∇f(y) - let &SupportProductFirst(ref f, ref g) = self.base_fn(); + let &SupportProductFirst(ref f, ref g) = self; let (df, dg) = (f.diff_ref(), g.diff_ref()); - [ - df.lipschitz_factor(m).map(|l| l * g.bounds().uniform()), - dg.lipschitz_factor(m).map(|l| l * f.bounds().uniform()), - f.lipschitz_factor(m).map(|l| l * dg.norm_bound(L2)), - g.lipschitz_factor(m).map(|l| l * df.norm_bound(L2)) - ].into_iter().sum() + Ok([ + df.lipschitz_factor(m)? * g.bounds().uniform(), + dg.lipschitz_factor(m)? * f.bounds().uniform(), + f.lipschitz_factor(m)? * dg.norm_bound(L2), + g.lipschitz_factor(m)? 
* df.norm_bound(L2), + ] + .into_iter() + .sum()) } } - -impl<'a, A, B, F : Float, const N : usize> Support -for SupportProductFirst +impl<'a, A, B, F: Float, const N: usize> Support for SupportProductFirst where - A : Support, - B : Support + A: Support, + B: Support, { #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { self.0.support_hint() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { self.0.in_support(x) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { self.0.bisection_hint(cube) } } -impl<'a, A, B, F : Float> GlobalAnalysis> -for SupportProductFirst -where A : GlobalAnalysis>, - B : GlobalAnalysis> { +impl<'a, A, B, F: Float> GlobalAnalysis> for SupportProductFirst +where + A: GlobalAnalysis>, + B: GlobalAnalysis>, +{ #[inline] fn global_analysis(&self) -> Bounds { self.0.global_analysis() * self.1.global_analysis() } } -impl<'a, A, B, F : Float, const N : usize> LocalAnalysis, N> -for SupportProductFirst -where A : LocalAnalysis, N>, - B : LocalAnalysis, N> { +impl<'a, A, B, F: Float, const N: usize> LocalAnalysis, N> + for SupportProductFirst +where + A: LocalAnalysis, N>, + B: LocalAnalysis, N>, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { self.0.local_analysis(cube) * self.1.local_analysis(cube) } } @@ -169,125 +156,116 @@ /// The kernels typically implement [`Support`] and [`Mapping`]. /// /// The implementation [`Support`] only uses the [`Support::support_hint`] of the first parameter! 
-#[derive(Copy,Clone,Serialize,Debug)] +#[derive(Copy, Clone, Serialize, Debug)] pub struct SupportSum( /// First kernel pub A, /// Second kernel - pub B + pub B, ); -impl<'a, A, B, F : Float, const N : usize> Mapping> -for SupportSum +impl<'a, A, B, F: Float, const N: usize> Mapping> for SupportSum where - A : Mapping, Codomain = F>, - B : Mapping, Codomain = F>, + A: Mapping, Codomain = F>, + B: Mapping, Codomain = F>, { type Codomain = F; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { - self.0.apply(x.ref_instance()) + self.1.apply(x) + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_ref(|r| self.0.apply(r)) + self.1.apply(x) } } -impl<'a, A, B, F : Float, const N : usize> DifferentiableImpl> -for SupportSum +impl<'a, A, B, F: Float, const N: usize> DifferentiableImpl> for SupportSum where - A : DifferentiableMapping< - Loc, - DerivativeDomain = Loc - >, - B : DifferentiableMapping< - Loc, - DerivativeDomain = Loc, - > + A: DifferentiableMapping, DerivativeDomain = Loc>, + B: DifferentiableMapping, DerivativeDomain = Loc>, { - - type Derivative = Loc; + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, x : I) -> Self::Derivative { - self.0.differential(x.ref_instance()) + self.1.differential(x) + fn differential_impl>>(&self, x: I) -> Self::Derivative { + x.eval_ref(|r| self.0.differential(r)) + self.1.differential(x) } } - -impl<'a, A, B, F : Float, const N : usize> Support -for SupportSum -where A : Support, - B : Support, - Cube : SetOrd { - +impl<'a, A, B, F: Float, const N: usize> Support for SupportSum +where + A: Support, + B: Support, + Cube: SetOrd, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { self.0.support_hint().common(&self.1.support_hint()) } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { self.0.in_support(x) || self.1.in_support(x) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { - 
map2(self.0.bisection_hint(cube), - self.1.bisection_hint(cube), - |a, b| a.or(b)) + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { + map2( + self.0.bisection_hint(cube), + self.1.bisection_hint(cube), + |a, b| a.or(b), + ) } } -impl<'a, A, B, F : Float> GlobalAnalysis> -for SupportSum -where A : GlobalAnalysis>, - B : GlobalAnalysis> { +impl<'a, A, B, F: Float> GlobalAnalysis> for SupportSum +where + A: GlobalAnalysis>, + B: GlobalAnalysis>, +{ #[inline] fn global_analysis(&self) -> Bounds { self.0.global_analysis() + self.1.global_analysis() } } -impl<'a, A, B, F : Float, const N : usize> LocalAnalysis, N> -for SupportSum -where A : LocalAnalysis, N>, - B : LocalAnalysis, N>, - Cube : SetOrd { +impl<'a, A, B, F: Float, const N: usize> LocalAnalysis, N> for SupportSum +where + A: LocalAnalysis, N>, + B: LocalAnalysis, N>, + Cube: SetOrd, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { self.0.local_analysis(cube) + self.1.local_analysis(cube) } } -impl Lipschitz for SupportSum -where A : Lipschitz, - B : Lipschitz { +impl Lipschitz for SupportSum +where + A: Lipschitz, + B: Lipschitz, +{ type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { - match (self.0.lipschitz_factor(m), self.1.lipschitz_factor(m)) { - (Some(l0), Some(l1)) => Some(l0 + l1), - _ => None - } + fn lipschitz_factor(&self, m: M) -> DynResult { + Ok(self.0.lipschitz_factor(m)? * self.1.lipschitz_factor(m)?) 
} } -impl<'b, F : Float, M : Copy, A, B, Domain> Lipschitz -for Differential<'b, Domain, SupportSum> +impl<'b, F: Float, M: Copy, A, B, Domain> LipschitzDifferentiableImpl + for SupportSum where - Domain : Space, - A : Clone + DifferentiableMapping, - B : Clone + DifferentiableMapping, - SupportSum : DifferentiableMapping, - for<'a> A :: Differential<'a> : Lipschitz, - for<'a> B :: Differential<'a> : Lipschitz + Domain: Space, + Self: DifferentiableImpl, + A: DifferentiableMapping, + B: DifferentiableMapping, + SupportSum: DifferentiableMapping, + for<'a> A::Differential<'a>: Lipschitz, + for<'a> B::Differential<'a>: Lipschitz, { type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { - let base = self.base_fn(); - base.0.diff_ref().lipschitz_factor(m) - .zip(base.1.diff_ref().lipschitz_factor(m)) - .map(|(a, b)| a + b) + fn diff_lipschitz_factor(&self, m: M) -> DynResult { + Ok(self.0.diff_ref().lipschitz_factor(m)? + self.1.diff_ref().lipschitz_factor(m)?) } } @@ -296,48 +274,52 @@ /// The kernels typically implement [`Support`]s and [`Mapping`]. // /// Trait implementations have to be on a case-by-case basis. -#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] pub struct Convolution( /// First kernel pub A, /// Second kernel - pub B + pub B, ); -impl Lipschitz for Convolution -where A : Norm , - B : Lipschitz { +impl Lipschitz for Convolution +where + A: Norm, + B: Lipschitz, +{ type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { + fn lipschitz_factor(&self, m: M) -> DynResult { // For [f * g](x) = ∫ f(x-y)g(y) dy we have // [f * g](x) - [f * g](z) = ∫ [f(x-y)-f(z-y)]g(y) dy. // Hence |[f * g](x) - [f * g](z)| ≤ ∫ |f(x-y)-f(z-y)|g(y)| dy. // ≤ L|x-z| ∫ |g(y)| dy, // where L is the Lipschitz factor of f. - self.1.lipschitz_factor(m).map(|l| l * self.0.norm(L1)) + Ok(self.1.lipschitz_factor(m)? 
* self.0.norm(L1)) } } -impl<'b, F : Float, M, A, B, Domain> Lipschitz -for Differential<'b, Domain, Convolution> +impl<'b, F: Float, M, A, B, Domain> LipschitzDifferentiableImpl for Convolution where - Domain : Space, - A : Clone + Norm , - Convolution : DifferentiableMapping, - B : Clone + DifferentiableMapping, - for<'a> B :: Differential<'a> : Lipschitz + Domain: Space, + Self: DifferentiableImpl, + A: Norm, + Convolution: DifferentiableMapping, + B: DifferentiableMapping, + for<'a> B::Differential<'a>: Lipschitz, { type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { + fn diff_lipschitz_factor(&self, m: M) -> DynResult { // For [f * g](x) = ∫ f(x-y)g(y) dy we have // ∇[f * g](x) - ∇[f * g](z) = ∫ [∇f(x-y)-∇f(z-y)]g(y) dy. // Hence |∇[f * g](x) - ∇[f * g](z)| ≤ ∫ |∇f(x-y)-∇f(z-y)|g(y)| dy. // ≤ L|x-z| ∫ |g(y)| dy, // where L is the Lipschitz factor of ∇f. - let base = self.base_fn(); - base.1.diff_ref().lipschitz_factor(m).map(|l| l * base.0.norm(L1)) + self.1 + .diff_ref() + .lipschitz_factor(m) + .map(|l| l * self.0.norm(L1)) } } @@ -346,78 +328,76 @@ /// The kernel typically implements [`Support`] and [`Mapping`]. /// /// Trait implementations have to be on a case-by-case basis. 
-#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] pub struct AutoConvolution( /// The kernel to be autoconvolved - pub A + pub A, ); -impl Lipschitz for AutoConvolution -where C : Lipschitz + Norm { +impl Lipschitz for AutoConvolution +where + C: Lipschitz + Norm, +{ type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { + fn lipschitz_factor(&self, m: M) -> DynResult { self.0.lipschitz_factor(m).map(|l| l * self.0.norm(L1)) } } -impl<'b, F : Float, M, C, Domain> Lipschitz -for Differential<'b, Domain, AutoConvolution> +impl<'b, F: Float, M, C, Domain> LipschitzDifferentiableImpl for AutoConvolution where - Domain : Space, - C : Clone + Norm + DifferentiableMapping, - AutoConvolution : DifferentiableMapping, - for<'a> C :: Differential<'a> : Lipschitz + Domain: Space, + Self: DifferentiableImpl, + C: Norm + DifferentiableMapping, + AutoConvolution: DifferentiableMapping, + for<'a> C::Differential<'a>: Lipschitz, { type FloatType = F; - fn lipschitz_factor(&self, m : M) -> Option { - let base = self.base_fn(); - base.0.diff_ref().lipschitz_factor(m).map(|l| l * base.0.norm(L1)) + fn diff_lipschitz_factor(&self, m: M) -> DynResult { + self.0 + .diff_ref() + .lipschitz_factor(m) + .map(|l| l * self.0.norm(L1)) } } - /// Representation a multi-dimensional product of a one-dimensional kernel. /// /// For $G: ℝ → ℝ$, this is the function $F(x\_1, …, x\_n) := \prod_{i=1}^n G(x\_i)$. /// The kernel $G$ typically implements [`Support`] and [`Mapping`] -/// on [`Loc`]. Then the product implements them on [`Loc`]. -#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] +/// on [`Loc<1, F>`]. Then the product implements them on [`Loc`]. 
+#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] #[allow(dead_code)] -struct UniformProduct( +struct UniformProduct( /// The one-dimensional kernel - G + G, ); -impl<'a, G, F : Float, const N : usize> Mapping> -for UniformProduct +impl<'a, G, F: Float, const N: usize> Mapping> for UniformProduct where - G : Mapping, Codomain = F> + G: Mapping, Codomain = F>, { type Codomain = F; #[inline] - fn apply>>(&self, x : I) -> F { - x.cow().iter().map(|&y| self.0.apply(Loc([y]))).product() + fn apply>>(&self, x: I) -> F { + x.decompose() + .iter() + .map(|&y| self.0.apply(Loc([y]))) + .product() } } - - -impl<'a, G, F : Float, const N : usize> DifferentiableImpl> -for UniformProduct +impl<'a, G, F: Float, const N: usize> DifferentiableImpl> for UniformProduct where - G : DifferentiableMapping< - Loc, - DerivativeDomain = F, - Codomain = F, - > + G: DifferentiableMapping, DerivativeDomain = F, Codomain = F>, { - type Derivative = Loc; + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, x0 : I) -> Loc { + fn differential_impl>>(&self, x0: I) -> Loc { x0.eval(|x| { let vs = x.map(|y| self.0.apply(Loc([y]))); product_differential(x, &vs, |y| self.0.differential(Loc([y]))) @@ -430,17 +410,19 @@ /// The vector `x` is the location, `vs` consists of the values `g(x_i)`, and /// `gd` calculates the derivative `g'`. #[inline] -pub(crate) fn product_differential F, const N : usize>( - x : &Loc, - vs : &Loc, - gd : G -) -> Loc { +pub(crate) fn product_differential F, const N: usize>( + x: &Loc, + vs: &Loc, + gd: G, +) -> Loc { map1_indexed(x, |i, &y| { - gd(y) * vs.iter() - .zip(0..) - .filter_map(|(v, j)| (j != i).then_some(*v)) - .product() - }).into() + gd(y) + * vs.iter() + .zip(0..) + .filter_map(|(v, j)| (j != i).then_some(*v)) + .product() + }) + .into() } /// Helper function to calulate the Lipschitz factor of $∇f$ for $f(x)=∏_{i=1}^N g(x_i)$. 
@@ -448,12 +430,12 @@ /// The parameter `bound` is a bound on $|g|_∞$, `lip` is a Lipschitz factor for $g$, /// `dbound` is a bound on $|∇g|_∞$, and `dlip` a Lipschitz factor for $∇g$. #[inline] -pub(crate) fn product_differential_lipschitz_factor( - bound : F, - lip : F, - dbound : F, - dlip : F -) -> F { +pub(crate) fn product_differential_lipschitz_factor( + bound: F, + lip: F, + dbound: F, + dlip: F, +) -> DynResult { // For arbitrary ψ(x) = ∏_{i=1}^n ψ_i(x_i), we have // ψ(x) - ψ(y) = ∑_i [ψ_i(x_i)-ψ_i(y_i)] ∏_{j ≠ i} ψ_j(x_j) // by a simple recursive argument. In particular, if ψ_i=g for all i, j, we have @@ -470,31 +452,33 @@ // = n [L_{∇g} M_g^{n-1} + (n-1) L_g M_g^{n-2} M_{∇g}]. // = n M_g^{n-2}[L_{∇g} M_g + (n-1) L_g M_{∇g}]. if N >= 2 { - F::cast_from(N) * bound.powi((N-2) as i32) - * (dlip * bound + F::cast_from(N-1) * lip * dbound) - } else if N==1 { - dlip + Ok(F::cast_from(N) + * bound.powi((N - 2) as i32) + * (dlip * bound + F::cast_from(N - 1) * lip * dbound)) + } else if N == 1 { + Ok(dlip) } else { - panic!("Invalid dimension") + Err(anyhow!("Invalid dimension")) } } -impl Support -for UniformProduct -where G : Support { +impl Support for UniformProduct +where + G: Support<1, F>, +{ #[inline] - fn support_hint(&self) -> Cube { - let [a] : [[F; 2]; 1] = self.0.support_hint().into(); + fn support_hint(&self) -> Cube { + let [a]: [[F; 2]; 1] = self.0.support_hint().into(); array_init(|| a.clone()).into() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { x.iter().all(|&y| self.0.in_support(&Loc([y]))) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { cube.map(|a, b| { let [h] = self.0.bisection_hint(&[[a, b]].into()); h @@ -502,9 +486,10 @@ } } -impl GlobalAnalysis> -for UniformProduct -where G : GlobalAnalysis> { +impl GlobalAnalysis> for UniformProduct +where + G: GlobalAnalysis>, +{ #[inline] fn global_analysis(&self) -> 
Bounds { let g = self.0.global_analysis(); @@ -512,88 +497,91 @@ } } -impl LocalAnalysis, N> -for UniformProduct -where G : LocalAnalysis, 1> { +impl LocalAnalysis, N> for UniformProduct +where + G: LocalAnalysis, 1>, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { - cube.iter_coords().map( - |&[a, b]| self.0.local_analysis(&([[a, b]].into())) - ).product() + fn local_analysis(&self, cube: &Cube) -> Bounds { + cube.iter_coords() + .map(|&[a, b]| self.0.local_analysis(&([[a, b]].into()))) + .product() } } macro_rules! product_lpnorm { ($lp:ident) => { - impl Norm - for UniformProduct - where G : Norm { + impl Norm<$lp, F> for UniformProduct + where + G: Norm<$lp, F>, + { #[inline] - fn norm(&self, lp : $lp) -> F { + fn norm(&self, lp: $lp) -> F { self.0.norm(lp).powi(N as i32) } } - } + }; } product_lpnorm!(L1); product_lpnorm!(L2); product_lpnorm!(Linfinity); - /// Trait for bounding one kernel with respect to another. /// /// The type `F` is the scalar field, and `T` another kernel to which `Self` is compared. -pub trait BoundedBy { +pub trait BoundedBy { /// Calclate a bounding factor $c$ such that the Fourier transforms $ℱ\[v\] ≤ c ℱ\[u\]$ for /// $v$ `self` and $u$ `other`. /// /// If no such factors exits, `None` is returned. - fn bounding_factor(&self, other : &T) -> Option; + fn bounding_factor(&self, other: &T) -> DynResult; } /// This [`BoundedBy`] implementation bounds $(uv) * (uv)$ by $(ψ * ψ) u$. 
#[replace_float_literals(F::cast_from(literal))] -impl -BoundedBy, BaseP>> -for AutoConvolution> -where F : Float, - C : Clone + PartialEq, - BaseP : Fourier + PartialOrd, // TODO: replace by BoundedBy, - >::Transformed : Bounded + Norm { - - fn bounding_factor(&self, kernel : &SupportProductFirst, BaseP>) -> Option { +impl BoundedBy, BaseP>> + for AutoConvolution> +where + F: Float, + C: PartialEq, + BaseP: Fourier + PartialOrd, // TODO: replace by BoundedBy, + >::Transformed: Bounded + Norm, +{ + fn bounding_factor( + &self, + kernel: &SupportProductFirst, BaseP>, + ) -> DynResult { let SupportProductFirst(AutoConvolution(ref cutoff2), base_spread2) = kernel; let AutoConvolution(SupportProductFirst(ref cutoff, ref base_spread)) = self; let v̂ = base_spread.fourier(); // Verify that the cut-off and ideal physical model (base spread) are the same. - if cutoff == cutoff2 - && base_spread <= base_spread2 - && v̂.bounds().lower() >= 0.0 { + if cutoff == cutoff2 && base_spread <= base_spread2 && v̂.bounds().lower() >= 0.0 { // Calculate the factor between the convolution approximation // `AutoConvolution>` of $A_*A$ and the // kernel of the seminorm. This depends on the physical model P being // `SupportProductFirst` with the kernel `K` being // a `SupportSum` involving `SupportProductFirst, BaseP>`. 
- Some(v̂.norm(L1)) + Ok(v̂.norm(L1)) } else { // We cannot compare - None + Err(anyhow!("Incomprable kernels")) } } } -impl BoundedBy> for A -where A : BoundedBy, - C : Bounded { - +impl BoundedBy> for A +where + A: BoundedBy, + C: Bounded, +{ #[replace_float_literals(F::cast_from(literal))] - fn bounding_factor(&self, SupportSum(ref kernel1, kernel2) : &SupportSum) -> Option { + fn bounding_factor(&self, SupportSum(ref kernel1, kernel2): &SupportSum) -> DynResult { if kernel2.bounds().lower() >= 0.0 { self.bounding_factor(kernel1) } else { - None + Err(anyhow!("Component kernel not lower-bounded by zero")) } } } @@ -603,7 +591,7 @@ /// It will attempt to place the subdivision point at $-r$ or $r$. /// If neither of these points lies within $[a, b]$, `None` is returned. #[inline] -pub(super) fn symmetric_interval_hint(r : F, a : F, b : F) -> Option { +pub(super) fn symmetric_interval_hint(r: F, a: F, b: F) -> Option { if a < -r && -r < b { Some(-r) } else if a < r && r < b { @@ -622,7 +610,7 @@ /// returned. #[replace_float_literals(F::cast_from(literal))] #[inline] -pub(super) fn symmetric_peak_hint(r : F, a : F, b : F) -> Option { +pub(super) fn symmetric_peak_hint(r: F, a: F, b: F) -> Option { let stage1 = if a < -r { if b <= -r { None @@ -648,7 +636,7 @@ // Ignore stage1 hint if either side of subdivision would be just a small fraction of the // interval match stage1 { - Some(h) if (h - a).min(b-h) >= 0.3 * r => Some(h), - _ => None + Some(h) if (h - a).min(b - h) >= 0.3 * r => Some(h), + _ => None, } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/gaussian.rs --- a/src/kernels/gaussian.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/gaussian.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,58 +1,51 @@ //! Implementation of the gaussian kernel. 
+use alg_tools::bisection_tree::{Constant, Support, Weighted}; +use alg_tools::bounds::{Bounded, Bounds, GlobalAnalysis, LocalAnalysis}; +use alg_tools::euclidean::Euclidean; +use alg_tools::loc::Loc; +use alg_tools::mapping::{ + DifferentiableImpl, Differential, Instance, LipschitzDifferentiableImpl, Mapping, +}; +use alg_tools::maputil::array_init; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use alg_tools::types::*; use float_extras::f64::erf; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Constant, - Bounds, - LocalAnalysis, - GlobalAnalysis, - Weighted, - Bounded, -}; -use alg_tools::mapping::{ - Mapping, - Instance, - Differential, - DifferentiableImpl, -}; -use alg_tools::maputil::array_init; -use crate::types::*; +use super::ball_indicator::CubeIndicator; +use super::base::*; use crate::fourier::Fourier; -use super::base::*; -use super::ball_indicator::CubeIndicator; +use crate::types::*; /// Storage presentation of the the anisotropic gaussian kernel of `variance` $σ^2$. /// /// This is the function $f(x) = C e^{-\\|x\\|\_2^2/(2σ^2)}$ for $x ∈ ℝ^N$ /// with $C=1/(2πσ^2)^{N/2}$. -#[derive(Copy,Clone,Debug,Serialize,Eq)] -pub struct Gaussian { +#[derive(Copy, Clone, Debug, Serialize, Eq)] +pub struct Gaussian { /// The variance $σ^2$. 
- pub variance : S, + pub variance: S, } -impl PartialEq> for Gaussian -where S1 : Constant, - S2 : Constant { - fn eq(&self, other : &Gaussian) -> bool { +impl PartialEq> for Gaussian +where + S1: Constant, + S2: Constant, +{ + fn eq(&self, other: &Gaussian) -> bool { self.variance.value() == other.variance.value() } } -impl PartialOrd> for Gaussian -where S1 : Constant, - S2 : Constant { - - fn partial_cmp(&self, other : &Gaussian) -> Option { +impl PartialOrd> for Gaussian +where + S1: Constant, + S2: Constant, +{ + fn partial_cmp(&self, other: &Gaussian) -> Option { // A gaussian is ≤ another gaussian if the Fourier transforms satisfy the // corresponding inequality. That in turns holds if and only if the variances // satisfy the opposite inequality. @@ -62,18 +55,17 @@ } } - #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, S, const N : usize> Mapping> for Gaussian +impl<'a, S, const N: usize> Mapping> for Gaussian where - S : Constant + S: Constant, { type Codomain = S::Type; // This is not normalised to neither to have value 1 at zero or integral 1 // (unless the cut-off ε=0). 
#[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { let d_squared = x.eval(|x| x.norm2_squared()); let σ2 = self.variance.value(); let scale = self.scale(); @@ -82,19 +74,20 @@ } #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, S, const N : usize> DifferentiableImpl> for Gaussian -where S : Constant { - type Derivative = Loc; +impl<'a, S, const N: usize> DifferentiableImpl> for Gaussian +where + S: Constant, +{ + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, x0 : I) -> Self::Derivative { - let x = x0.cow(); + fn differential_impl>>(&self, x0: I) -> Self::Derivative { + let x = x0.decompose(); let f = -self.apply(&*x) / self.variance.value(); *x * f } } - // To calculate the the Lipschitz factors, we consider // f(t) = e^{-t²/2} // f'(t) = -t f(t) which has max at t=1 by f''(t)=0 @@ -117,25 +110,26 @@ // Hence the Lipschitz factor of ∇g is (C/σ²)f''(√3) = (C/σ²)2e^{-3/2}. #[replace_float_literals(S::Type::cast_from(literal))] -impl Lipschitz for Gaussian -where S : Constant { +impl Lipschitz for Gaussian +where + S: Constant, +{ type FloatType = S::Type; - fn lipschitz_factor(&self, L2 : L2) -> Option { - Some((-0.5).exp() / (self.scale() * self.variance.value().sqrt())) + fn lipschitz_factor(&self, L2: L2) -> DynResult { + Ok((-0.5).exp() / (self.scale() * self.variance.value().sqrt())) } } - #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, S : Constant, const N : usize> Lipschitz -for Differential<'a, Loc, Gaussian> { +impl<'a, S: Constant, const N: usize> LipschitzDifferentiableImpl, L2> + for Gaussian +{ type FloatType = S::Type; - - fn lipschitz_factor(&self, _l2 : L2) -> Option { - let g = self.base_fn(); - let σ2 = g.variance.value(); - let scale = g.scale(); - Some(2.0*(-3.0/2.0).exp()/(σ2*scale)) + + fn diff_lipschitz_factor(&self, _l2: L2) -> DynResult { + let σ2 = self.variance.value(); + let scale = self.scale(); + Ok(2.0 * (-3.0 / 2.0).exp() / (σ2 * 
scale)) } } @@ -146,65 +140,73 @@ // factors of the undifferentiated function, given how the latter is calculed above. #[replace_float_literals(S::Type::cast_from(literal))] -impl<'b, S : Constant, const N : usize> NormBounded -for Differential<'b, Loc, Gaussian> { +impl<'b, S: Constant, const N: usize> NormBounded + for Differential<'b, Loc, Gaussian> +{ type FloatType = S::Type; - - fn norm_bound(&self, _l2 : L2) -> S::Type { + + fn norm_bound(&self, _l2: L2) -> S::Type { self.base_fn().lipschitz_factor(L2).unwrap() } } #[replace_float_literals(S::Type::cast_from(literal))] -impl<'b, 'a, S : Constant, const N : usize> NormBounded -for Differential<'b, Loc, &'a Gaussian> { +impl<'b, 'a, S: Constant, const N: usize> NormBounded + for Differential<'b, Loc, &'a Gaussian> +{ type FloatType = S::Type; - - fn norm_bound(&self, _l2 : L2) -> S::Type { + + fn norm_bound(&self, _l2: L2) -> S::Type { self.base_fn().lipschitz_factor(L2).unwrap() } } - #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, S, const N : usize> Gaussian -where S : Constant { - +impl<'a, S, const N: usize> Gaussian +where + S: Constant, +{ /// Returns the (reciprocal) scaling constant $1/C=(2πσ^2)^{N/2}$. 
#[inline] pub fn scale(&self) -> S::Type { let π = S::Type::PI; let σ2 = self.variance.value(); - (2.0*π*σ2).powi(N as i32).sqrt() + (2.0 * π * σ2).powi(N as i32).sqrt() } } -impl<'a, S, const N : usize> Support for Gaussian -where S : Constant { +impl<'a, S, const N: usize> Support for Gaussian +where + S: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { array_init(|| [S::Type::NEG_INFINITY, S::Type::INFINITY]).into() } #[inline] - fn in_support(&self, _x : &Loc) -> bool { + fn in_support(&self, _x: &Loc) -> bool { true } } #[replace_float_literals(S::Type::cast_from(literal))] -impl GlobalAnalysis> for Gaussian -where S : Constant { +impl GlobalAnalysis> for Gaussian +where + S: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { - Bounds(0.0, 1.0/self.scale()) + Bounds(0.0, 1.0 / self.scale()) } } -impl LocalAnalysis, N> for Gaussian -where S : Constant { +impl LocalAnalysis, N> for Gaussian +where + S: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the 2-norm is minimised/maximised. 
let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); @@ -213,68 +215,63 @@ } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Norm -for Gaussian { +impl<'a, C: Constant, const N: usize> Norm for Gaussian { #[inline] - fn norm(&self, _ : L1) -> C::Type { + fn norm(&self, _: L1) -> C::Type { 1.0 } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Norm -for Gaussian { +impl<'a, C: Constant, const N: usize> Norm for Gaussian { #[inline] - fn norm(&self, _ : Linfinity) -> C::Type { + fn norm(&self, _: Linfinity) -> C::Type { self.bounds().upper() } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Fourier -for Gaussian { - type Domain = Loc; +impl<'a, C: Constant, const N: usize> Fourier for Gaussian { + type Domain = Loc; type Transformed = Weighted, C::Type>; #[inline] fn fourier(&self) -> Self::Transformed { let π = C::Type::PI; let σ2 = self.variance.value(); - let g = Gaussian { variance : 1.0 / (4.0*π*π*σ2) }; + let g = Gaussian { variance: 1.0 / (4.0 * π * π * σ2) }; g.weigh(g.scale()) } } /// Representation of the “cut” gaussian $f χ\_{[-a, a]^n}$ /// where $a>0$ and $f$ is a gaussian kernel on $ℝ^n$. -pub type BasicCutGaussian = SupportProductFirst, - Gaussian>; - +pub type BasicCutGaussian = + SupportProductFirst, Gaussian>; /// This implements $g := χ\_{[-b, b]^n} \* (f χ\_{[-a, a]^n})$ where $a,b>0$ and $f$ is /// a gaussian kernel on $ℝ^n$. For an expression for $g$, see Lemma 3.9 in the manuscript. 
#[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, C, S, const N : usize> Mapping> -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { - +impl<'a, F: Float, R, C, S, const N: usize> Mapping> + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ type Codomain = F; #[inline] - fn apply>>(&self, y : I) -> F { - let Convolution(ref ind, - SupportProductFirst(ref cut, - ref gaussian)) = self; + fn apply>>(&self, y: I) -> F { + let Convolution(ref ind, SupportProductFirst(ref cut, ref gaussian)) = self; let a = cut.r.value(); let b = ind.r.value(); let σ = gaussian.variance.value().sqrt(); let t = F::SQRT_2 * σ; let c = 0.5; // 1/(σ√(2π) * σ√(π/2) = 1/2 - + // This is just a product of one-dimensional versions - y.cow().product_map(|x| { + y.decompose().product_map(|x| { let c1 = -(a.min(b + x)); //(-a).max(-x-b); let c2 = a.min(b - x); if c1 >= c2 { @@ -293,28 +290,27 @@ /// and $f$ is a gaussian kernel on $ℝ^n$. For an expression for the value of $g$, from which the /// derivative readily arises (at points of differentiability), see Lemma 3.9 in the manuscript. #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, C, S, const N : usize> DifferentiableImpl> -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { - - type Derivative = Loc; +impl<'a, F: Float, R, C, S, const N: usize> DifferentiableImpl> + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ + type Derivative = Loc; /// Although implemented, this function is not differentiable. 
#[inline] - fn differential_impl>>(&self, y0 : I) -> Loc { - let Convolution(ref ind, - SupportProductFirst(ref cut, - ref gaussian)) = self; - let y = y0.cow(); + fn differential_impl>>(&self, y0: I) -> Loc { + let Convolution(ref ind, SupportProductFirst(ref cut, ref gaussian)) = self; + let y = y0.decompose(); let a = cut.r.value(); let b = ind.r.value(); let σ = gaussian.variance.value().sqrt(); let t = F::SQRT_2 * σ; let c = 0.5; // 1/(σ√(2π) * σ√(π/2) = 1/2 let c_mul_erf_scale_div_t = c * F::FRAC_2_SQRT_PI / t; - + // Calculate the values for all component functions of the // product. This is just the loop from apply above. let unscaled_vs = y.map(|x| { @@ -340,12 +336,12 @@ // from the chain rule (the minus comes from inside c_1 or c_2, and changes the // order of de2 and de1 in the final calculation). let de1 = if b + x < a { - (-((b+x)/t).powi(2)).exp() + (-((b + x) / t).powi(2)).exp() } else { 0.0 }; let de2 = if b - x < a { - (-((b-x)/t).powi(2)).exp() + (-((b - x) / t).powi(2)).exp() } else { 0.0 }; @@ -355,16 +351,17 @@ } } - #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, C, S, const N : usize> Lipschitz -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { +impl<'a, F: Float, R, C, S, const N: usize> Lipschitz + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ type FloatType = F; - fn lipschitz_factor(&self, L1 : L1) -> Option { + fn lipschitz_factor(&self, L1: L1) -> DynResult { // To get the product Lipschitz factor, we note that for any ψ_i, we have // ∏_{i=1}^N φ_i(x_i) - ∏_{i=1}^N φ_i(y_i) // = [φ_1(x_1)-φ_1(y_1)] ∏_{i=2}^N φ_i(x_i) @@ -398,9 +395,7 @@ // θ * ψ(x) = θ * ψ(y). If only y is in the range [-(a+b), a+b], we can replace // x by -(a+b) or (a+b), either of which is closer to y and still θ * ψ(x)=0. // Thus same calculations as above work for the Lipschitz factor. 
- let Convolution(ref ind, - SupportProductFirst(ref cut, - ref gaussian)) = self; + let Convolution(ref ind, SupportProductFirst(ref cut, ref gaussian)) = self; let a = cut.r.value(); let b = ind.r.value(); let σ = gaussian.variance.value().sqrt(); @@ -408,7 +403,7 @@ let t = F::SQRT_2 * σ; let l1d = F::SQRT_2 / (π.sqrt() * σ); let e0 = F::cast_from(erf((a.min(b) / t).as_())); - Some(l1d * e0.powi(N as i32-1)) + Ok(l1d * e0.powi(N as i32 - 1)) } } @@ -426,39 +421,40 @@ } */ -impl -Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { - +impl Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ #[inline] fn get_r(&self) -> F { - let Convolution(ref ind, - SupportProductFirst(ref cut, ..)) = self; + let Convolution(ref ind, SupportProductFirst(ref cut, ..)) = self; ind.r.value() + cut.r.value() } } -impl Support -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { +impl Support + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let r = self.get_r(); array_init(|| [-r, r]).into() } #[inline] - fn in_support(&self, y : &Loc) -> bool { + fn in_support(&self, y: &Loc) -> bool { let r = self.get_r(); y.iter().all(|x| x.abs() <= r) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let r = self.get_r(); // From c1 = -a.min(b + x) and c2 = a.min(b - x) with c_1 < c_2, // solve bounds for x. that is 0 ≤ a.min(b + x) + a.min(b - x). 
@@ -470,28 +466,31 @@ } } -impl GlobalAnalysis> -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { +impl GlobalAnalysis> + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { Bounds(F::ZERO, self.apply(Loc::ORIGIN)) } } -impl LocalAnalysis, N> -for Convolution, BasicCutGaussian> -where R : Constant, - C : Constant, - S : Constant { +impl LocalAnalysis, N> + for Convolution, BasicCutGaussian> +where + R: Constant, + C: Constant, + S: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the absolute value is minimised/maximised. let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); Bounds(lower, upper) } } - diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/hat.rs --- a/src/kernels/hat.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/hat.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,60 +1,55 @@ //! Implementation of the hat function +use crate::types::Lipschitz; +use alg_tools::bisection_tree::{Constant, Support}; +use alg_tools::bounds::{Bounded, Bounds, GlobalAnalysis, LocalAnalysis}; +use alg_tools::error::DynResult; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Instance, Mapping}; +use alg_tools::maputil::array_init; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use alg_tools::types::*; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Constant, - Bounds, - LocalAnalysis, - GlobalAnalysis, - Bounded, -}; -use alg_tools::mapping::{Mapping, Instance}; -use alg_tools::maputil::array_init; -use crate::types::Lipschitz; /// Representation of the hat function $f(x)=1-\\|x\\|\_1/ε$ of `width` $ε$ on $ℝ^N$. 
-#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] -pub struct Hat { +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] +pub struct Hat { /// The parameter $ε>0$. - pub width : C, + pub width: C, } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Mapping> for Hat { +impl<'a, C: Constant, const N: usize> Mapping> for Hat { type Codomain = C::Type; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { let ε = self.width.value(); - 0.0.max(1.0-x.cow().norm(L1)/ε) + 0.0.max(1.0 - x.decompose().norm(L1) / ε) } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Support for Hat { +impl<'a, C: Constant, const N: usize> Support for Hat { #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let ε = self.width.value(); array_init(|| [-ε, ε]).into() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { x.norm(L1) < self.width.value() } - + /*fn fully_in_support(&self, _cube : &Cube) -> bool { todo!("Not implemented, but not used at the moment") }*/ #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let ε = self.width.value(); cube.map(|a, b| { if a < 1.0 { @@ -62,24 +57,29 @@ Some(1.0) } else { if a < -ε { - if b > -ε { Some(-ε) } else { None } + if b > -ε { + Some(-ε) + } else { + None + } } else { None } } } else { - if b > ε { Some(ε) } else { None } + if b > ε { + Some(ε) + } else { + None + } } }); todo!("also diagonals") } } - #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> -GlobalAnalysis> -for Hat { +impl<'a, C: Constant, const N: usize> GlobalAnalysis> for Hat { #[inline] fn global_analysis(&self) -> Bounds { Bounds(0.0, 1.0) @@ -87,30 +87,27 @@ } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N 
: usize> Lipschitz for Hat { +impl<'a, C: Constant, const N: usize> Lipschitz for Hat { type FloatType = C::Type; - fn lipschitz_factor(&self, _l1 : L1) -> Option { - Some(1.0/self.width.value()) + fn lipschitz_factor(&self, _l1: L1) -> DynResult { + Ok(1.0 / self.width.value()) } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Lipschitz for Hat { +impl<'a, C: Constant, const N: usize> Lipschitz for Hat { type FloatType = C::Type; - fn lipschitz_factor(&self, _l2 : L2) -> Option { - self.lipschitz_factor(L1).map(|l1| - >>::from_norm(&L2, l1, L1) - ) + fn lipschitz_factor(&self, _l2: L2) -> DynResult { + self.lipschitz_factor(L1) + .map(|l1| >>::from_norm(&L2, l1, L1)) } } -impl<'a, C : Constant, const N : usize> -LocalAnalysis, N> -for Hat { +impl<'a, C: Constant, const N: usize> LocalAnalysis, N> for Hat { #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the 1-norm is minimised/maximised. let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); @@ -119,12 +116,9 @@ } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> -Norm -for Hat { +impl<'a, C: Constant, const N: usize> Norm for Hat { #[inline] - fn norm(&self, _ : Linfinity) -> C::Type { + fn norm(&self, _: Linfinity) -> C::Type { self.bounds().upper() } } - diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/hat_convolution.rs --- a/src/kernels/hat_convolution.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/hat_convolution.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,30 +1,22 @@ //! Implementation of the convolution of two hat functions, //! and its convolution with a [`CubeIndicator`]. 
+ +use alg_tools::bisection_tree::{Constant, Support}; +use alg_tools::bounds::{Bounded, Bounds, GlobalAnalysis, LocalAnalysis}; +use alg_tools::error::DynResult; +use alg_tools::loc::Loc; +use alg_tools::mapping::{DifferentiableImpl, Instance, LipschitzDifferentiableImpl, Mapping}; +use alg_tools::maputil::array_init; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use alg_tools::types::*; +use anyhow::anyhow; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Constant, - Bounds, - LocalAnalysis, - GlobalAnalysis, - Bounded, -}; -use alg_tools::mapping::{ - Mapping, - Instance, - DifferentiableImpl, - Differential, -}; -use alg_tools::maputil::array_init; -use crate::types::Lipschitz; +use super::ball_indicator::CubeIndicator; use super::base::*; -use super::ball_indicator::CubeIndicator; +use crate::types::Lipschitz; /// Hat convolution kernel. /// @@ -69,21 +61,26 @@ // $$ // [∇f(x\_1, …, x\_n)]_j = \frac{4}{σ} (h\*h)'(x\_j/σ) \prod\_{j ≠ i} \frac{4}{σ} (h\*h)(x\_i/σ) // $$ -#[derive(Copy,Clone,Debug,Serialize,Eq)] -pub struct HatConv { +#[derive(Copy, Clone, Debug, Serialize, Eq)] +pub struct HatConv { /// The parameter $σ$ of the kernel. - pub radius : S, + pub radius: S, } -impl PartialEq> for HatConv -where S1 : Constant, - S2 : Constant { - fn eq(&self, other : &HatConv) -> bool { +impl PartialEq> for HatConv +where + S1: Constant, + S2: Constant, +{ + fn eq(&self, other: &HatConv) -> bool { self.radius.value() == other.radius.value() } } -impl<'a, S, const N : usize> HatConv where S : Constant { +impl<'a, S, const N: usize> HatConv +where + S: Constant, +{ /// Returns the $σ$ parameter of the kernel. 
#[inline] pub fn radius(&self) -> S::Type { @@ -91,25 +88,27 @@ } } -impl<'a, S, const N : usize> Mapping> for HatConv -where S : Constant { +impl<'a, S, const N: usize> Mapping> for HatConv +where + S: Constant, +{ type Codomain = S::Type; #[inline] - fn apply>>(&self, y : I) -> Self::Codomain { + fn apply>>(&self, y: I) -> Self::Codomain { let σ = self.radius(); - y.cow().product_map(|x| { - self.value_1d_σ1(x / σ) / σ - }) + y.decompose().product_map(|x| self.value_1d_σ1(x / σ) / σ) } } #[replace_float_literals(S::Type::cast_from(literal))] -impl Lipschitz for HatConv -where S : Constant { +impl Lipschitz for HatConv +where + S: Constant, +{ type FloatType = S::Type; #[inline] - fn lipschitz_factor(&self, L1 : L1) -> Option { + fn lipschitz_factor(&self, L1: L1) -> DynResult { // For any ψ_i, we have // ∏_{i=1}^N ψ_i(x_i) - ∏_{i=1}^N ψ_i(y_i) // = [ψ_1(x_1)-ψ_1(y_1)] ∏_{i=2}^N ψ_i(x_i) @@ -119,86 +118,91 @@ // |∏_{i=1}^N ψ_i(x_i) - ∏_{i=1}^N ψ_i(y_i)| // ≤ ∑_{j=1}^N |ψ_j(x_j)-ψ_j(y_j)| ∏_{j ≠ i} \max_j |ψ_j| let σ = self.radius(); - let l1d = self.lipschitz_1d_σ1() / (σ*σ); + let l1d = self.lipschitz_1d_σ1() / (σ * σ); let m1d = self.value_1d_σ1(0.0) / σ; - Some(l1d * m1d.powi(N as i32 - 1)) + Ok(l1d * m1d.powi(N as i32 - 1)) } } -impl Lipschitz for HatConv -where S : Constant { +impl Lipschitz for HatConv +where + S: Constant, +{ type FloatType = S::Type; #[inline] - fn lipschitz_factor(&self, L2 : L2) -> Option { - self.lipschitz_factor(L1).map(|l1| l1 * ::cast_from(N).sqrt()) + fn lipschitz_factor(&self, L2: L2) -> DynResult { + self.lipschitz_factor(L1) + .map(|l1| l1 * ::cast_from(N).sqrt()) } } - -impl<'a, S, const N : usize> DifferentiableImpl> for HatConv -where S : Constant { - type Derivative = Loc; +impl<'a, S, const N: usize> DifferentiableImpl> for HatConv +where + S: Constant, +{ + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, y0 : I) -> Self::Derivative { - let y = y0.cow(); + fn differential_impl>>(&self, y0: I) -> 
Self::Derivative { + let y = y0.decompose(); let σ = self.radius(); let σ2 = σ * σ; - let vs = y.map(|x| { - self.value_1d_σ1(x / σ) / σ - }); - product_differential(&*y, &vs, |x| { - self.diff_1d_σ1(x / σ) / σ2 - }) + let vs = y.map(|x| self.value_1d_σ1(x / σ) / σ); + product_differential(&*y, &vs, |x| self.diff_1d_σ1(x / σ) / σ2) } } - #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, F : Float, S, const N : usize> Lipschitz -for Differential<'a, Loc, HatConv> -where S : Constant { +impl<'a, F: Float, S, const N: usize> LipschitzDifferentiableImpl, L2> + for HatConv +where + S: Constant, +{ type FloatType = F; #[inline] - fn lipschitz_factor(&self, _l2 : L2) -> Option { - let h = self.base_fn(); - let σ = h.radius(); - Some(product_differential_lipschitz_factor::( - h.value_1d_σ1(0.0) / σ, - h.lipschitz_1d_σ1() / (σ*σ), - h.maxabsdiff_1d_σ1() / (σ*σ), - h.lipschitz_diff_1d_σ1() / (σ*σ), - )) + fn diff_lipschitz_factor(&self, _l2: L2) -> DynResult { + let σ = self.radius(); + product_differential_lipschitz_factor::( + self.value_1d_σ1(0.0) / σ, + self.lipschitz_1d_σ1() / (σ * σ), + self.maxabsdiff_1d_σ1() / (σ * σ), + self.lipschitz_diff_1d_σ1() / (σ * σ), + ) } } - #[replace_float_literals(S::Type::cast_from(literal))] -impl<'a, F : Float, S, const N : usize> HatConv -where S : Constant { +impl<'a, F: Float, S, const N: usize> HatConv +where + S: Constant, +{ /// Computes the value of the kernel for $n=1$ with $σ=1$. #[inline] - fn value_1d_σ1(&self, x : F) -> F { + fn value_1d_σ1(&self, x: F) -> F { let y = x.abs(); if y >= 1.0 { 0.0 } else if y > 0.5 { - - (8.0/3.0) * (y - 1.0).powi(3) - } else /* 0 ≤ y ≤ 0.5 */ { - (4.0/3.0) + 8.0 * y * y * (y - 1.0) + -(8.0 / 3.0) * (y - 1.0).powi(3) + } else + /* 0 ≤ y ≤ 0.5 */ + { + (4.0 / 3.0) + 8.0 * y * y * (y - 1.0) } } /// Computes the differential of the kernel for $n=1$ with $σ=1$. 
#[inline] - fn diff_1d_σ1(&self, x : F) -> F { + fn diff_1d_σ1(&self, x: F) -> F { let y = x.abs(); if y >= 1.0 { 0.0 } else if y > 0.5 { - - 8.0 * (y - 1.0).powi(2) - } else /* 0 ≤ y ≤ 0.5 */ { + -8.0 * (y - 1.0).powi(2) + } else + /* 0 ≤ y ≤ 0.5 */ + { (24.0 * y - 16.0) * y } } @@ -220,13 +224,15 @@ /// Computes the second differential of the kernel for $n=1$ with $σ=1$. #[inline] #[allow(dead_code)] - fn diff2_1d_σ1(&self, x : F) -> F { + fn diff2_1d_σ1(&self, x: F) -> F { let y = x.abs(); if y >= 1.0 { 0.0 } else if y > 0.5 { - - 16.0 * (y - 1.0) - } else /* 0 ≤ y ≤ 0.5 */ { + -16.0 * (y - 1.0) + } else + /* 0 ≤ y ≤ 0.5 */ + { 48.0 * y - 16.0 } } @@ -239,40 +245,46 @@ } } -impl<'a, S, const N : usize> Support for HatConv -where S : Constant { +impl<'a, S, const N: usize> Support for HatConv +where + S: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let σ = self.radius(); array_init(|| [-σ, σ]).into() } #[inline] - fn in_support(&self, y : &Loc) -> bool { + fn in_support(&self, y: &Loc) -> bool { let σ = self.radius(); y.iter().all(|x| x.abs() <= σ) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let σ = self.radius(); cube.map(|c, d| symmetric_peak_hint(σ, c, d)) } } #[replace_float_literals(S::Type::cast_from(literal))] -impl GlobalAnalysis> for HatConv -where S : Constant { +impl GlobalAnalysis> for HatConv +where + S: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { Bounds(0.0, self.apply(Loc::ORIGIN)) } } -impl LocalAnalysis, N> for HatConv -where S : Constant { +impl LocalAnalysis, N> for HatConv +where + S: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the 2-norm is minimised/maximised. 
let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); @@ -281,39 +293,38 @@ } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Norm -for HatConv { +impl<'a, C: Constant, const N: usize> Norm for HatConv { #[inline] - fn norm(&self, _ : L1) -> C::Type { + fn norm(&self, _: L1) -> C::Type { 1.0 } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Norm -for HatConv { +impl<'a, C: Constant, const N: usize> Norm for HatConv { #[inline] - fn norm(&self, _ : Linfinity) -> C::Type { + fn norm(&self, _: Linfinity) -> C::Type { self.bounds().upper() } } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, C, const N : usize> Mapping> -for Convolution, HatConv> -where R : Constant, - C : Constant { - +impl<'a, F: Float, R, C, const N: usize> Mapping> + for Convolution, HatConv> +where + R: Constant, + C: Constant, +{ type Codomain = F; #[inline] - fn apply>>(&self, y : I) -> F { + fn apply>>(&self, y: I) -> F { let Convolution(ref ind, ref hatconv) = self; let β = ind.r.value(); let σ = hatconv.radius(); // This is just a product of one-dimensional versions - y.cow().product_map(|x| { + y.decompose().product_map(|x| { // With $u_σ(x) = u_1(x/σ)/σ$ the normalised hat convolution // we have // $$ @@ -329,37 +340,32 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, R, C, const N : usize> DifferentiableImpl> -for Convolution, HatConv> -where R : Constant, - C : Constant { - - type Derivative = Loc; +impl<'a, F: Float, R, C, const N: usize> DifferentiableImpl> + for Convolution, HatConv> +where + R: Constant, + C: Constant, +{ + type Derivative = Loc; #[inline] - fn differential_impl>>(&self, y0 : I) -> Loc { - let y = y0.cow(); + fn differential_impl>>(&self, y0: I) -> Loc { + let y = y0.decompose(); let Convolution(ref ind, ref hatconv) = self; let β = ind.r.value(); let σ = hatconv.radius(); let σ2 = σ * 
σ; - let vs = y.map(|x| { - self.value_1d_σ1(x / σ, β / σ) - }); - product_differential(&*y, &vs, |x| { - self.diff_1d_σ1(x / σ, β / σ) / σ2 - }) + let vs = y.map(|x| self.value_1d_σ1(x / σ, β / σ)); + product_differential(&*y, &vs, |x| self.diff_1d_σ1(x / σ, β / σ) / σ2) } } - /// Integrate $f$, whose support is $[c, d]$, on $[a, b]$. /// If $b > d$, add $g()$ to the result. #[inline] #[replace_float_literals(F::cast_from(literal))] -fn i(a : F, b : F, c : F, d : F, f : impl Fn(F) -> F, - g : impl Fn() -> F) -> F { +fn i(a: F, b: F, c: F, d: F, f: impl Fn(F) -> F, g: impl Fn() -> F) -> F { if b < c { 0.0 } else if b <= d { @@ -368,7 +374,9 @@ } else { f(b) - f(a) } - } else /* b > d */ { + } else + /* b > d */ + { g() + if a <= c { f(d) - f(c) } else if a < d { @@ -380,84 +388,130 @@ } #[replace_float_literals(F::cast_from(literal))] -impl Convolution, HatConv> -where R : Constant, - C : Constant { - +impl Convolution, HatConv> +where + R: Constant, + C: Constant, +{ /// Calculates the value of the 1D hat convolution further convolved by a interval indicator. /// As both functions are piecewise polynomials, this is implemented by explicit integral over /// all subintervals of polynomiality of the cube indicator, using easily formed /// antiderivatives. #[inline] - pub fn value_1d_σ1(&self, x : F, β : F) -> F { + pub fn value_1d_σ1(&self, x: F, β: F) -> F { // The integration interval let a = x - β; let b = x + β; #[inline] - fn pow4(x : F) -> F { + fn pow4(x: F) -> F { let y = x * x; y * y } - + // Observe the factor 1/6 at the front from the antiderivatives below. // The factor 4 is from normalisation of the original function. 
- (4.0/6.0) * i(a, b, -1.0, -0.5, + (4.0 / 6.0) + * i( + a, + b, + -1.0, + -0.5, // (2/3) (y+1)^3 on -1 < y ≤ -1/2 // The antiderivative is (2/12)(y+1)^4 = (1/6)(y+1)^4 - |y| pow4(y+1.0), - || i(a, b, -0.5, 0.0, - // -2 y^3 - 2 y^2 + 1/3 on -1/2 < y ≤ 0 - // The antiderivative is -1/2 y^4 - 2/3 y^3 + 1/3 y - |y| y*(-y*y*(y*3.0 + 4.0) + 2.0), - || i(a, b, 0.0, 0.5, - // 2 y^3 - 2 y^2 + 1/3 on 0 < y < 1/2 - // The antiderivative is 1/2 y^4 - 2/3 y^3 + 1/3 y - |y| y*(y*y*(y*3.0 - 4.0) + 2.0), - || i(a, b, 0.5, 1.0, - // -(2/3) (y-1)^3 on 1/2 < y ≤ 1 - // The antiderivative is -(2/12)(y-1)^4 = -(1/6)(y-1)^4 - |y| -pow4(y-1.0), - || 0.0 + |y| pow4(y + 1.0), + || { + i( + a, + b, + -0.5, + 0.0, + // -2 y^3 - 2 y^2 + 1/3 on -1/2 < y ≤ 0 + // The antiderivative is -1/2 y^4 - 2/3 y^3 + 1/3 y + |y| y * (-y * y * (y * 3.0 + 4.0) + 2.0), + || { + i( + a, + b, + 0.0, + 0.5, + // 2 y^3 - 2 y^2 + 1/3 on 0 < y < 1/2 + // The antiderivative is 1/2 y^4 - 2/3 y^3 + 1/3 y + |y| y * (y * y * (y * 3.0 - 4.0) + 2.0), + || { + i( + a, + b, + 0.5, + 1.0, + // -(2/3) (y-1)^3 on 1/2 < y ≤ 1 + // The antiderivative is -(2/12)(y-1)^4 = -(1/6)(y-1)^4 + |y| -pow4(y - 1.0), + || 0.0, + ) + }, ) + }, ) - ) - ) + }, + ) } /// Calculates the derivative of the 1D hat convolution further convolved by a interval /// indicator. The implementation is similar to [`Self::value_1d_σ1`], using the fact that /// $(θ * ψ)' = θ * ψ'$. #[inline] - pub fn diff_1d_σ1(&self, x : F, β : F) -> F { + pub fn diff_1d_σ1(&self, x: F, β: F) -> F { // The integration interval let a = x - β; let b = x + β; // The factor 4 is from normalisation of the original function. 
- 4.0 * i(a, b, -1.0, -0.5, - // (2/3) (y+1)^3 on -1 < y ≤ -1/2 - |y| (2.0/3.0) * (y + 1.0).powi(3), - || i(a, b, -0.5, 0.0, + 4.0 * i( + a, + b, + -1.0, + -0.5, + // (2/3) (y+1)^3 on -1 < y ≤ -1/2 + |y| (2.0 / 3.0) * (y + 1.0).powi(3), + || { + i( + a, + b, + -0.5, + 0.0, // -2 y^3 - 2 y^2 + 1/3 on -1/2 < y ≤ 0 - |y| -2.0*(y + 1.0) * y * y + (1.0/3.0), - || i(a, b, 0.0, 0.5, + |y| -2.0 * (y + 1.0) * y * y + (1.0 / 3.0), + || { + i( + a, + b, + 0.0, + 0.5, // 2 y^3 - 2 y^2 + 1/3 on 0 < y < 1/2 - |y| 2.0*(y - 1.0) * y * y + (1.0/3.0), - || i(a, b, 0.5, 1.0, - // -(2/3) (y-1)^3 on 1/2 < y ≤ 1 - |y| -(2.0/3.0) * (y - 1.0).powi(3), - || 0.0 - ) - ) + |y| 2.0 * (y - 1.0) * y * y + (1.0 / 3.0), + || { + i( + a, + b, + 0.5, + 1.0, + // -(2/3) (y-1)^3 on 1/2 < y ≤ 1 + |y| -(2.0 / 3.0) * (y - 1.0).powi(3), + || 0.0, + ) + }, + ) + }, ) + }, ) } } /* impl<'a, F : Float, R, C, const N : usize> Lipschitz -for Differential, Convolution, HatConv>> +for Differential, Convolution, HatConv>> where R : Constant, C : Constant { @@ -471,11 +525,11 @@ } */ -impl -Convolution, HatConv> -where R : Constant, - C : Constant { - +impl Convolution, HatConv> +where + R: Constant, + C: Constant, +{ #[inline] fn get_r(&self) -> F { let Convolution(ref ind, ref hatconv) = self; @@ -483,25 +537,26 @@ } } -impl Support -for Convolution, HatConv> -where R : Constant, - C : Constant { - +impl Support + for Convolution, HatConv> +where + R: Constant, + C: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let r = self.get_r(); array_init(|| [-r, r]).into() } #[inline] - fn in_support(&self, y : &Loc) -> bool { + fn in_support(&self, y: &Loc) -> bool { let r = self.get_r(); y.iter().all(|x| x.abs() <= r) } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { // It is not difficult to verify that [`HatConv`] is C^2. 
// Therefore, so is [`Convolution, HatConv>`] so that a finer // subdivision for the hint than this is not particularly useful. @@ -510,22 +565,26 @@ } } -impl GlobalAnalysis> -for Convolution, HatConv> -where R : Constant, - C : Constant { +impl GlobalAnalysis> + for Convolution, HatConv> +where + R: Constant, + C: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { Bounds(F::ZERO, self.apply(Loc::ORIGIN)) } } -impl LocalAnalysis, N> -for Convolution, HatConv> -where R : Constant, - C : Constant { +impl LocalAnalysis, N> + for Convolution, HatConv> +where + R: Constant, + C: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the absolute value is minimised/maximised. let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); @@ -542,7 +601,6 @@ } } - /// This [`BoundedBy`] implementation bounds $u * u$ by $(ψ * ψ) u$ for $u$ a hat convolution and /// $ψ = χ_{[-a,a]^N}$ for some $a>0$. /// @@ -550,17 +608,18 @@ /// where we take $ψ = χ_{[-a,a]^N}$ and $χ = χ_{[-σ,σ]^N}$ for $σ$ the width of the hat /// convolution. #[replace_float_literals(F::cast_from(literal))] -impl -BoundedBy>, HatConv>> -for AutoConvolution> -where F : Float, - C : Constant, - S : Constant { - +impl + BoundedBy>, HatConv>> + for AutoConvolution> +where + F: Float, + C: Constant, + S: Constant, +{ fn bounding_factor( &self, - kernel : &SupportProductFirst>, HatConv> - ) -> Option { + kernel: &SupportProductFirst>, HatConv>, + ) -> DynResult { // We use the comparison $ℱ[𝒜(ψ v)] ≤ L_1 ℱ[𝒜(ψ)u] ⟺ I_{v̂} v̂ ≤ L_1 û$ with // $ψ = χ_{[-w, w]}$ satisfying $supp v ⊂ [-w, w]$, i.e. $w ≥ σ$. Here $v̂ = ℱ[v]$ and // $I_{v̂} = ∫ v̂ d ξ. For this relationship to be valid, we need $v̂ ≥ 0$, which is guaranteed @@ -574,10 +633,10 @@ // `SupportProductFirst>, HatConv>` // is wide enough, and that the hat convolution has the same radius as ours. 
if σ <= a && hatconv2 == &self.0 { - Some(bounding_1d.powi(N as i32)) + Ok(bounding_1d.powi(N as i32)) } else { // We cannot compare - None + Err(anyhow!("Incomparable factors")) } } } @@ -586,46 +645,44 @@ /// /// This is based on Example 3.3 in the manuscript. #[replace_float_literals(F::cast_from(literal))] -impl -BoundedBy> -for AutoConvolution> -where F : Float, - C : Constant { - +impl BoundedBy> for AutoConvolution> +where + F: Float, + C: Constant, +{ /// Returns an estimate of the factor $L_1$. /// /// Returns `None` if `kernel` does not have the same width as hat convolution that `self` /// is based on. - fn bounding_factor( - &self, - kernel : &HatConv - ) -> Option { + fn bounding_factor(&self, kernel: &HatConv) -> DynResult { if kernel == &self.0 { - Some(1.0) + Ok(1.0) } else { // We cannot compare - None + Err(anyhow!("Incomparable kernels")) } } } #[cfg(test)] mod tests { + use super::HatConv; + use crate::kernels::{BallIndicator, Convolution, CubeIndicator}; use alg_tools::lingrid::linspace; + use alg_tools::loc::Loc; use alg_tools::mapping::Mapping; use alg_tools::norms::Linfinity; - use alg_tools::loc::Loc; - use crate::kernels::{BallIndicator, CubeIndicator, Convolution}; - use super::HatConv; /// Tests numerically that [`HatConv`] is monotone. 
#[test] fn hatconv_monotonicity() { let grid = linspace(0.0, 1.0, 100000); - let hatconv : HatConv = HatConv{ radius : 1.0 }; + let hatconv: HatConv = HatConv { radius: 1.0 }; let mut vals = grid.into_iter().map(|t| hatconv.apply(Loc::from(t))); let first = vals.next().unwrap(); - let monotone = vals.fold((first, true), |(prev, ok), t| (prev, ok && prev >= t)).1; + let monotone = vals + .fold((first, true), |(prev, ok), t| (prev, ok && prev >= t)) + .1; assert!(monotone); } @@ -633,21 +690,27 @@ #[test] fn convolution_cubeind_hatconv_monotonicity() { let grid = linspace(-2.0, 0.0, 100000); - let hatconv : Convolution, HatConv> - = Convolution(BallIndicator { r : 0.5, exponent : Linfinity }, - HatConv{ radius : 1.0 } ); + let hatconv: Convolution, HatConv> = + Convolution(BallIndicator { r: 0.5, exponent: Linfinity }, HatConv { + radius: 1.0, + }); let mut vals = grid.into_iter().map(|t| hatconv.apply(Loc::from(t))); let first = vals.next().unwrap(); - let monotone = vals.fold((first, true), |(prev, ok), t| (prev, ok && prev <= t)).1; + let monotone = vals + .fold((first, true), |(prev, ok), t| (prev, ok && prev <= t)) + .1; assert!(monotone); let grid = linspace(0.0, 2.0, 100000); - let hatconv : Convolution, HatConv> - = Convolution(BallIndicator { r : 0.5, exponent : Linfinity }, - HatConv{ radius : 1.0 } ); + let hatconv: Convolution, HatConv> = + Convolution(BallIndicator { r: 0.5, exponent: Linfinity }, HatConv { + radius: 1.0, + }); let mut vals = grid.into_iter().map(|t| hatconv.apply(Loc::from(t))); let first = vals.next().unwrap(); - let monotone = vals.fold((first, true), |(prev, ok), t| (prev, ok && prev >= t)).1; + let monotone = vals + .fold((first, true), |(prev, ok), t| (prev, ok && prev >= t)) + .1; assert!(monotone); } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/linear.rs --- a/src/kernels/linear.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/linear.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,94 +1,82 @@ //! 
Implementation of the linear function +use alg_tools::bisection_tree::Support; +use alg_tools::bounds::{Bounded, Bounds, GlobalAnalysis, LocalAnalysis}; +use alg_tools::loc::Loc; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use alg_tools::types::*; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Bounds, - LocalAnalysis, - GlobalAnalysis, - Bounded, -}; -use alg_tools::mapping::{Mapping, Instance}; + +use alg_tools::euclidean::Euclidean; +use alg_tools::mapping::{Instance, Mapping}; use alg_tools::maputil::array_init; -use alg_tools::euclidean::Euclidean; /// Representation of the hat function $f(x)=1-\\|x\\|\_1/ε$ of `width` $ε$ on $ℝ^N$. -#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] -pub struct Linear { +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] +pub struct Linear { /// The parameter $ε>0$. 
- pub v : Loc, + pub v: Loc, } #[replace_float_literals(F::cast_from(literal))] -impl Mapping> for Linear { +impl Mapping> for Linear { type Codomain = F; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { x.eval(|x| self.v.dot(x)) } } - #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, const N : usize> Support for Linear { +impl<'a, F: Float, const N: usize> Support for Linear { #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { array_init(|| [F::NEG_INFINITY, F::INFINITY]).into() } #[inline] - fn in_support(&self, _x : &Loc) -> bool { + fn in_support(&self, _x: &Loc) -> bool { true } - + /*fn fully_in_support(&self, _cube : &Cube) -> bool { todo!("Not implemented, but not used at the moment") }*/ #[inline] - fn bisection_hint(&self, _cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, _cube: &Cube) -> [Option; N] { [None; N] } } - #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, const N : usize> -GlobalAnalysis> -for Linear { +impl<'a, F: Float, const N: usize> GlobalAnalysis> for Linear { #[inline] fn global_analysis(&self) -> Bounds { Bounds(F::NEG_INFINITY, F::INFINITY) } } -impl<'a, F : Float, const N : usize> -LocalAnalysis, N> -for Linear { +impl<'a, F: Float, const N: usize> LocalAnalysis, N> for Linear { #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { - let (lower, upper) = cube.iter_corners() - .map(|x| self.apply(x)) - .fold((F::INFINITY, F::NEG_INFINITY), |(lower, upper), v| { - (lower.min(v), upper.max(v)) - }); + fn local_analysis(&self, cube: &Cube) -> Bounds { + let (lower, upper) = cube + .iter_corners() + .map(|x| self.apply(x)) + .fold((F::INFINITY, F::NEG_INFINITY), |(lower, upper), v| { + (lower.min(v), upper.max(v)) + }); Bounds(lower, upper) } } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float, const N : usize> -Norm -for Linear { +impl<'a, F: Float, const N: usize> Norm for Linear 
{ #[inline] - fn norm(&self, _ : Linfinity) -> F { + fn norm(&self, _: Linfinity) -> F { self.bounds().upper() } } - diff -r 9738b51d90d7 -r 4f468d35fa29 src/kernels/mollifier.rs --- a/src/kernels/mollifier.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/kernels/mollifier.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,24 +1,17 @@ - //! Implementation of the standard mollifier -use rgsl::hypergeometric::hyperg_U; +use alg_tools::bisection_tree::{Bounds, Constant, GlobalAnalysis, LocalAnalysis, Support}; +use alg_tools::euclidean::Euclidean; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Instance, Mapping}; +use alg_tools::maputil::array_init; +use alg_tools::norms::*; +use alg_tools::sets::Cube; +use alg_tools::types::*; use float_extras::f64::tgamma as gamma; use numeric_literals::replace_float_literals; +use rgsl::hypergeometric::hyperg_U; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::*; -use alg_tools::loc::Loc; -use alg_tools::sets::Cube; -use alg_tools::bisection_tree::{ - Support, - Constant, - Bounds, - LocalAnalysis, - GlobalAnalysis -}; -use alg_tools::mapping::{Mapping, Instance}; -use alg_tools::maputil::array_init; /// Reresentation of the (unnormalised) standard mollifier. /// @@ -29,20 +22,20 @@ /// 0, & \text{otherwise}. /// \end{cases} /// $$
-#[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] -pub struct Mollifier { +#[derive(Copy, Clone, Serialize, Debug, Eq, PartialEq)] +pub struct Mollifier { /// The parameter $ε$ of the mollifier. - pub width : C, + pub width: C, } #[replace_float_literals(C::Type::cast_from(literal))] -impl Mapping> for Mollifier { +impl Mapping> for Mollifier { type Codomain = C::Type; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { let ε = self.width.value(); - let ε2 = ε*ε; + let ε2 = ε * ε; let n2 = x.eval(|x| x.norm2_squared()); if n2 < ε2 { (n2 / (n2 - ε2)).exp() @@ -52,27 +45,25 @@ } } - -impl<'a, C : Constant, const N : usize> Support for Mollifier { +impl<'a, C: Constant, const N: usize> Support for Mollifier { #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { let ε = self.width.value(); array_init(|| [-ε, ε]).into() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { x.norm2() < self.width.value() } - + /*fn fully_in_support(&self, _cube : &Cube) -> bool { todo!("Not implemented, but not used at the moment") }*/ } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> GlobalAnalysis> -for Mollifier { +impl<'a, C: Constant, const N: usize> GlobalAnalysis> for Mollifier { #[inline] fn global_analysis(&self) -> Bounds { // The function is maximised/minimised where the 2-norm is minimised/maximised. @@ -80,10 +71,11 @@ } } -impl<'a, C : Constant, const N : usize> LocalAnalysis, N> -for Mollifier { +impl<'a, C: Constant, const N: usize> LocalAnalysis, N> + for Mollifier +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { // The function is maximised/minimised where the 2-norm is minimised/maximised. 
let lower = self.apply(cube.maxnorm_point()); let upper = self.apply(cube.minnorm_point()); @@ -99,7 +91,7 @@ /// If `rescaled` is `true`, return the integral of the scaled mollifier that has value one at the /// origin. #[inline] -pub fn mollifier_norm1(n_ : usize, rescaled : bool) -> f64 { +pub fn mollifier_norm1(n_: usize, rescaled: bool) -> f64 { assert!(n_ > 0); let n = n_ as f64; let q = 2.0; @@ -108,23 +100,25 @@ /*/ gamma(1.0 + n / p) * gamma(1.0 + n / q)*/ * hyperg_U(1.0 + n / q, 2.0, 1.0); - if rescaled { base } else { base / f64::E } + if rescaled { + base + } else { + base / f64::E + } } -impl<'a, C : Constant, const N : usize> Norm -for Mollifier { +impl<'a, C: Constant, const N: usize> Norm for Mollifier { #[inline] - fn norm(&self, _ : L1) -> C::Type { + fn norm(&self, _: L1) -> C::Type { let ε = self.width.value(); C::Type::cast_from(mollifier_norm1(N, true)) * ε.powi(N as i32) } } #[replace_float_literals(C::Type::cast_from(literal))] -impl<'a, C : Constant, const N : usize> Norm -for Mollifier { +impl<'a, C: Constant, const N: usize> Norm for Mollifier { #[inline] - fn norm(&self, _ : Linfinity) -> C::Type { + fn norm(&self, _: Linfinity) -> C::Type { 1.0 } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/lib.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/lib.rs Thu Feb 26 11:38:43 2026 -0500 @@ -0,0 +1,280 @@ +// The main documentation is in the README. +// We need to uglify it in build.rs because rustdoc is stuck in the past. +#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] +// We use unicode. We would like to use much more of it than Rust allows. +// Live with it. Embrace it. +#![allow(uncommon_codepoints)] +#![allow(mixed_script_confusables)] +#![allow(confusable_idents)] +// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical +// convention while referring to the type (trait) of the operator as `A`. 
+#![allow(non_snake_case)] +// Need to create parse errors +#![feature(dec2flt)] + +use alg_tools::error::DynResult; +use alg_tools::parallelism::{set_max_threads, set_num_threads}; +use clap::Parser; +use serde::{Deserialize, Serialize}; +use serde_json; +use serde_with::skip_serializing_none; +use std::num::NonZeroUsize; + +//#[cfg(feature = "pyo3")] +//use pyo3::pyclass; + +pub mod dataterm; +pub mod experiments; +pub mod fb; +pub mod forward_model; +pub mod forward_pdps; +pub mod fourier; +pub mod frank_wolfe; +pub mod kernels; +pub mod pdps; +pub mod plot; +pub mod preadjoint_helper; +pub mod prox_penalty; +pub mod rand_distr; +pub mod regularisation; +pub mod run; +pub mod seminorms; +pub mod sliding_fb; +pub mod sliding_pdps; +pub mod subproblem; +pub mod tolerance; +pub mod types; + +pub mod measures { + pub use measures::*; +} + +use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment}; +use types::{ClapFloat, Float}; +use DefaultAlgorithm::*; + +/// Trait for customising the experiments available from the command line +pub trait ExperimentSetup: + clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> +{ + /// Type of floating point numbers to be used. + type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>; + + fn runnables(&self) -> DynResult>>>; +} + +/// Command line parameters +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Default, Clone)] +pub struct CommandLineArgs { + #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] + /// Maximum iteration count + max_iter: usize, + + #[arg(long, short = 'n', value_name = "N")] + /// Output status every N iterations. Set to 0 to disable. + /// + /// The default is to output status based on logarithmic increments. + verbose_iter: Option, + + #[arg(long, short = 'q')] + /// Don't display iteration progress + quiet: bool, + + /// Default algorithm configration(s) to use on the experiments. 
+ /// + /// Not all algorithms are available for all the experiments. + /// In particular, only PDPS is available for the experiments with L¹ data term. + #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', + default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] + algorithm: Vec, + + /// Saved algorithm configration(s) to use on the experiments + #[arg(value_name = "JSON_FILE", long)] + saved_algorithm: Vec, + + /// Plot saving scheme + #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] + plot: PlotLevel, + + /// Directory for saving results + #[arg(long, short = 'o', required = true, default_value = "out")] + outdir: String, + + #[arg(long, help_heading = "Multi-threading", default_value = "4")] + /// Maximum number of threads + max_threads: usize, + + #[arg(long, help_heading = "Multi-threading")] + /// Number of threads. Overrides the maximum number. + num_threads: Option, + + #[arg(long, default_value_t = false)] + /// Load saved value ranges (if exists) to do partial update. + load_valuerange: bool, +} + +#[derive(Parser, Debug, Serialize, Default, Clone)] +#[clap( + about = env!("CARGO_PKG_DESCRIPTION"), + author = env!("CARGO_PKG_AUTHORS"), + version = env!("CARGO_PKG_VERSION"), + after_help = "Pass --help for longer descriptions.", + after_long_help = "", +)] +struct FusedCommandLineArgs { + /// List of experiments to perform. 
+ #[clap(flatten, next_help_heading = "Experiment setup")] + experiment_setup: E, + + #[clap(flatten, next_help_heading = "General parameters")] + general: CommandLineArgs, + + #[clap(flatten, next_help_heading = "Algorithm overrides")] + /// Algorithm parametrisation overrides + algorithm_overrides: AlgorithmOverrides, +} + +/// Command line algorithm parametrisation overrides +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] +//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] +pub struct AlgorithmOverrides { + #[arg(long, value_names = &["COUNT", "EACH"])] + /// Override bootstrap insertion iterations for --algorithm. + /// + /// The first parameter is the number of bootstrap insertion iterations, and the second + /// the maximum number of iterations on each of them. + bootstrap_insertions: Option>, + + #[arg(long, requires = "algorithm")] + /// Primal step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Does not affect the algorithms fw and fwrelax. + tau0: Option, + + #[arg(long, requires = "algorithm")] + /// Second primal step length parameter override for SlidingPDPS. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. + sigmap0: Option, + + #[arg(long, requires = "algorithm")] + /// Dual step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Only affects PDPS. + sigma0: Option, + + #[arg(long)] + /// Normalised transport step length for sliding methods. + theta0: Option, + + #[arg(long)] + /// A posteriori transport tolerance multiplier (C_pos) + transport_tolerance_pos: Option, + + #[arg(long)] + /// Transport adaptation factor. Must be in (0, 1). 
+ transport_adaptation: Option, + + #[arg(long)] + /// Minimal step length parameter for sliding methods. + tau0_min: Option, + + #[arg(value_enum, long)] + /// PDPS acceleration, when available. + acceleration: Option, + + // #[arg(long)] + // /// Perform postprocess weight optimisation for saved iterations + // /// + // /// Only affects FB, FISTA, and PDPS. + // postprocessing : Option, + #[arg(value_name = "n", long)] + /// Merging frequency, if merging enabled (every n iterations) + /// + /// Only affects FB, FISTA, and PDPS. + merge_every: Option, + + #[arg(long)] + /// Enable merging (default: determined by algorithm) + merge: Option, + + #[arg(long)] + /// Merging radius (default: determined by experiment) + merge_radius: Option, + + #[arg(long)] + /// Interpolate when merging (default : determined by algorithm) + merge_interp: Option, + + #[arg(long)] + /// Enable final merging (default: determined by algorithm) + final_merging: Option, + + #[arg(long)] + /// Enable fitness-based merging for relevant FB-type methods. + /// This has worse convergence guarantees that merging based on optimality conditions. 
+ fitness_merging: Option, + + #[arg(long, value_names = &["ε", "θ", "p"])] + /// Set the tolerance to ε_k = ε/(1+θk)^p + tolerance: Option>, +} + +/// A generic entry point for binaries based on this library +pub fn common_main() -> DynResult<()> { + let full_cli = FusedCommandLineArgs::::parse(); + let cli = &full_cli.general; + + #[cfg(debug_assertions)] + { + use colored::Colorize; + println!( + "{}", + format!( + "\n\ + ********\n\ + WARNING: Compiled without optimisations; {}\n\ + Please recompile with `--release` flag.\n\ + ********\n\ + ", + "performance will be poor!".blink() + ) + .red() + ); + } + + if let Some(n_threads) = cli.num_threads { + let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); + set_num_threads(n); + } else { + let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); + set_max_threads(m); + } + + for experiment in full_cli.experiment_setup.runnables()? { + let mut algs: Vec>> = cli + .algorithm + .iter() + .map(|alg| { + let cfg = alg + .default_config() + .cli_override(&experiment.algorithm_overrides(*alg)) + .cli_override(&full_cli.algorithm_overrides); + alg.to_named(cfg) + }) + .collect(); + for filename in cli.saved_algorithm.iter() { + let f = std::fs::File::open(filename)?; + let alg = serde_json::from_reader(f)?; + algs.push(alg); + } + experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?; + } + + Ok(()) +} diff -r 9738b51d90d7 -r 4f468d35fa29 src/main.rs --- a/src/main.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/main.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,293 +2,9 @@ // We need to uglify it in build.rs because rustdoc is stuck in the past. #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] -// We use unicode. We would like to use much more of it than Rust allows. -// Live with it. Embrace it. -#![allow(uncommon_codepoints)] -#![allow(mixed_script_confusables)] -#![allow(confusable_idents)] -// Linear operators may be written e.g. 
as `opA`, to keep the capital letters of mathematical -// convention while referring to the type (trait) of the operator as `A`. -#![allow(non_snake_case)] -// Need to create parse errors -#![feature(dec2flt)] - -use clap::Parser; -use serde::{Serialize, Deserialize}; -use serde_json; -use serde_with::skip_serializing_none; -use itertools::Itertools; -use std::num::NonZeroUsize; - -use alg_tools::parallelism::{ - set_num_threads, - set_max_threads, -}; - -pub mod types; -pub mod measures; -pub mod fourier; -pub mod kernels; -pub mod seminorms; -pub mod forward_model; -pub mod preadjoint_helper; -pub mod plot; -pub mod subproblem; -pub mod tolerance; -pub mod regularisation; -pub mod dataterm; -pub mod prox_penalty; -pub mod fb; -pub mod sliding_fb; -pub mod sliding_pdps; -pub mod forward_pdps; -pub mod frank_wolfe; -pub mod pdps; -pub mod run; -pub mod rand_distr; -pub mod experiments; - -use types::{float, ClapFloat}; -use run::{ - DefaultAlgorithm, - PlotLevel, - Named, - AlgorithmConfig, -}; -use experiments::DefaultExperiment; -use DefaultExperiment::*; -use DefaultAlgorithm::*; - -/// Command line parameters -#[skip_serializing_none] -#[derive(Parser, Debug, Serialize, Default, Clone)] -#[clap( - about = env!("CARGO_PKG_DESCRIPTION"), - author = env!("CARGO_PKG_AUTHORS"), - version = env!("CARGO_PKG_VERSION"), - after_help = "Pass --help for longer descriptions.", - after_long_help = "", -)] -pub struct CommandLineArgs { - #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] - /// Maximum iteration count - max_iter : usize, - - #[arg(long, short = 'n', value_name = "N")] - /// Output status every N iterations. Set to 0 to disable. - /// - /// The default is to output status based on logarithmic increments. - verbose_iter : Option, - - #[arg(long, short = 'q')] - /// Don't display iteration progress - quiet : bool, - - /// List of experiments to perform. 
- #[arg(value_enum, value_name = "EXPERIMENT", - default_values_t = [Experiment1D, Experiment1DFast, - Experiment2D, Experiment2DFast, - Experiment1D_L1])] - experiments : Vec, - - /// Default algorithm configration(s) to use on the experiments. - /// - /// Not all algorithms are available for all the experiments. - /// In particular, only PDPS is available for the experiments with L¹ data term. - #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', - default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] - algorithm : Vec, - - /// Saved algorithm configration(s) to use on the experiments - #[arg(value_name = "JSON_FILE", long)] - saved_algorithm : Vec, - - /// Plot saving scheme - #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] - plot : PlotLevel, - - /// Directory for saving results - #[arg(long, short = 'o', required = true, default_value = "out")] - outdir : String, - - #[arg(long, help_heading = "Multi-threading", default_value = "4")] - /// Maximum number of threads - max_threads : usize, - - #[arg(long, help_heading = "Multi-threading")] - /// Number of threads. Overrides the maximum number. - num_threads : Option, - - #[arg(long, default_value_t = false)] - /// Load saved value ranges (if exists) to do partial update. - load_valuerange : bool, - - #[clap(flatten, next_help_heading = "Experiment overrides")] - /// Experiment setup overrides - experiment_overrides : ExperimentOverrides, - - #[clap(flatten, next_help_heading = "Algorithm overrides")] - /// Algorithm parametrisation overrides - algoritm_overrides : AlgorithmOverrides, -} - -/// Command line experiment setup overrides -#[skip_serializing_none] -#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] -pub struct ExperimentOverrides { - #[arg(long)] - /// Regularisation parameter override. - /// - /// Only use if running just a single experiment, as different experiments have different - /// regularisation parameters. 
- alpha : Option, - - #[arg(long)] - /// Gaussian noise variance override - variance : Option, - - #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] - /// Salt and pepper noise override. - salt_and_pepper : Option>, - - #[arg(long)] - /// Noise seed - noise_seed : Option, -} - -/// Command line algorithm parametrisation overrides -#[skip_serializing_none] -#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] -pub struct AlgorithmOverrides { - #[arg(long, value_names = &["COUNT", "EACH"])] - /// Override bootstrap insertion iterations for --algorithm. - /// - /// The first parameter is the number of bootstrap insertion iterations, and the second - /// the maximum number of iterations on each of them. - bootstrap_insertions : Option>, - - #[arg(long, requires = "algorithm")] - /// Primal step length parameter override for --algorithm. - /// - /// Only use if running just a single algorithm, as different algorithms have different - /// regularisation parameters. Does not affect the algorithms fw and fwrelax. - tau0 : Option, - - #[arg(long, requires = "algorithm")] - /// Second primal step length parameter override for SlidingPDPS. - /// - /// Only use if running just a single algorithm, as different algorithms have different - /// regularisation parameters. - sigmap0 : Option, - - #[arg(long, requires = "algorithm")] - /// Dual step length parameter override for --algorithm. - /// - /// Only use if running just a single algorithm, as different algorithms have different - /// regularisation parameters. Only affects PDPS. - sigma0 : Option, - - #[arg(long)] - /// Normalised transport step length for sliding methods. - theta0 : Option, - - #[arg(long)] - /// A posteriori transport tolerance multiplier (C_pos) - transport_tolerance_pos : Option, - - #[arg(long)] - /// Transport adaptation factor. Must be in (0, 1). - transport_adaptation : Option, - - #[arg(long)] - /// Minimal step length parameter for sliding methods. 
- tau0_min : Option, - - #[arg(value_enum, long)] - /// PDPS acceleration, when available. - acceleration : Option, - - // #[arg(long)] - // /// Perform postprocess weight optimisation for saved iterations - // /// - // /// Only affects FB, FISTA, and PDPS. - // postprocessing : Option, - - #[arg(value_name = "n", long)] - /// Merging frequency, if merging enabled (every n iterations) - /// - /// Only affects FB, FISTA, and PDPS. - merge_every : Option, - - #[arg(long)] - /// Enable merging (default: determined by algorithm) - merge : Option, - - #[arg(long)] - /// Merging radius (default: determined by experiment) - merge_radius : Option, - - #[arg(long)] - /// Interpolate when merging (default : determined by algorithm) - merge_interp : Option, - - #[arg(long)] - /// Enable final merging (default: determined by algorithm) - final_merging : Option, - - #[arg(long)] - /// Enable fitness-based merging for relevant FB-type methods. - /// This has worse convergence guarantees that merging based on optimality conditions. - fitness_merging : Option, - - #[arg(long, value_names = &["ε", "θ", "p"])] - /// Set the tolerance to ε_k = ε/(1+θk)^p - tolerance : Option>, - -} +use pointsource_algs::{common_main, experiments::DefaultExperimentSetup}; /// The entry point for the program. 
pub fn main() { - let cli = CommandLineArgs::parse(); - - #[cfg(debug_assertions)] - { - use colored::Colorize; - println!("{}", format!("\n\ - ********\n\ - WARNING: Compiled without optimisations; {}\n\ - Please recompile with `--release` flag.\n\ - ********\n\ - ", "performance will be poor!".blink() - ).red()); - } - - if let Some(n_threads) = cli.num_threads { - let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); - set_num_threads(n); - } else { - let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); - set_max_threads(m); - } - - for experiment_shorthand in cli.experiments.iter().unique() { - let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); - let mut algs : Vec>> - = cli.algorithm - .iter() - .map(|alg| { - let cfg = alg.default_config() - .cli_override(&experiment.algorithm_overrides(*alg)) - .cli_override(&cli.algoritm_overrides); - alg.to_named(cfg) - }) - .collect(); - for filename in cli.saved_algorithm.iter() { - let f = std::fs::File::open(filename).unwrap(); - let alg = serde_json::from_reader(f).unwrap(); - algs.push(alg); - } - experiment.runall(&cli, (!algs.is_empty()).then_some(algs)) - .unwrap() - } + common_main::>().unwrap(); } diff -r 9738b51d90d7 -r 4f468d35fa29 src/measures.rs --- a/src/measures.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,10 +0,0 @@ -//! This module implementes measures, in particular [`DeltaMeasure`]s and [`DiscreteMeasure`]s. - -mod base; -pub use base::*; -mod delta; -pub use delta::*; -mod discrete; -pub use discrete::*; -pub mod merging; - diff -r 9738b51d90d7 -r 4f468d35fa29 src/measures/base.rs --- a/src/measures/base.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,21 +0,0 @@ -//! 
Basic definitions for measures - -use serde::Serialize; -use alg_tools::types::Num; -use alg_tools::norms::{Norm, NormExponent}; - -/// This is used with [`Norm::norm`] to indicate that a Radon norm is to be computed. -#[derive(Copy,Clone,Serialize,Debug)] -pub struct Radon; -impl NormExponent for Radon {} - -/// A trait for (Radon) measures. -/// -/// Currently has no methods, just the requirement that the Radon norm be implemented. -pub trait Measure : Norm { - type Domain; -} - -/// Decomposition of measures -pub struct MeasureDecomp; - diff -r 9738b51d90d7 -r 4f468d35fa29 src/measures/delta.rs --- a/src/measures/delta.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,320 +0,0 @@ -/*! -This module implementes delta measures, i.e., single spikes $\alpha \delta_x$ for some -location $x$ and mass $\alpha$. -*/ - -use super::base::*; -use crate::types::*; -use std::ops::{Div, Mul, DivAssign, MulAssign, Neg}; -use serde::ser::{Serialize, Serializer, SerializeStruct}; -use alg_tools::norms::Norm; -use alg_tools::linops::{Mapping, Linear}; -use alg_tools::instance::{Instance, Space}; - -/// Representation of a delta measure. -/// -/// This is a single spike $\alpha \delta\_x$ for some location $x$ in `Domain` and -/// a mass $\alpha$ in `F`. -#[derive(Clone,Copy,Debug)] -pub struct DeltaMeasure { - // This causes [`csv`] to crash. - //#[serde(flatten)] - /// Location of the spike - pub x : Domain, - /// Mass of the spike - pub α : F -} - -const COORDINATE_NAMES : &'static [&'static str] = &[ - "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" -]; - -// Need to manually implement serialisation as [`csv`] writer fails on -// structs with nested arrays as well as with #[serde(flatten)]. 
-impl Serialize for DeltaMeasure, F> -where - F: Serialize, -{ - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - assert!(N <= COORDINATE_NAMES.len()); - - let mut s = serializer.serialize_struct("DeltaMeasure", N+1)?; - for (i, e) in (0..).zip(self.x.iter()) { - s.serialize_field(COORDINATE_NAMES[i], e)?; - } - s.serialize_field("weight", &self.α)?; - s.end() - } -} - - -impl Measure for DeltaMeasure { - type Domain = Domain; -} - -impl Norm for DeltaMeasure { - #[inline] - fn norm(&self, _ : Radon) -> F { - self.α.abs() - } -} - -// impl Dist for DeltaMeasure { -// #[inline] -// fn dist(&self, other : &Self, _ : Radon) -> F { -// if self.x == other. x { -// (self.α - other.α).abs() -// } else { -// self.α.abs() + other.α.abs() -// } -// } -// } - -impl Mapping for DeltaMeasure -where - Domain : Space, - G::Codomain : Mul, - G : Mapping + Clone + Space, - for<'b> &'b Domain : Instance, -{ - type Codomain = G::Codomain; - - #[inline] - fn apply>(&self, g : I) -> Self::Codomain { - g.eval(|g̃| g̃.apply(&self.x) * self.α) - } -} - -impl Linear for DeltaMeasure -where - Domain : Space, - G::Codomain : Mul, - G : Mapping + Clone + Space, - for<'b> &'b Domain : Instance, -{ } - -// /// Partial blanket implementation of [`DeltaMeasure`] as a linear functional of [`Mapping`]s. -// /// A full blanket implementation is not possible due to annoying Rust limitations: only [`Apply`] -// /// on a reference is implemented, but a consuming [`Apply`] has to be implemented on a case-by-case -// /// basis, not because an implementation could not be written, but because the Rust trait system -// /// chokes up. 
-// impl Linear for DeltaMeasure -// where G: for<'a> Apply<&'a Domain, Output = V>, -// V : Mul, -// Self: Apply>::Output> { -// type Codomain = >::Output; -// } - -// impl<'b, Domain, G, F : Num, V> Apply<&'b G> for DeltaMeasure -// where G: for<'a> Apply<&'a Domain, Output = V>, -// V : Mul { -// type Output = >::Output; - -// #[inline] -// fn apply(&self, g : &'b G) -> Self::Output { -// g.apply(&self.x) * self.α -// } -// } - -// /// Implementation of the necessary apply for BTFNs -// mod btfn_apply { -// use super::*; -// use alg_tools::bisection_tree::{BTFN, BTImpl, SupportGenerator, LocalAnalysis}; - -// impl Apply> -// for DeltaMeasure, F> -// where BT : BTImpl, -// G : SupportGenerator, -// G::SupportType : LocalAnalysis + for<'a> Apply<&'a Loc, Output = V>, -// V : std::iter::Sum + Mul { - -// type Output = >::Output; - -// #[inline] -// fn apply(&self, g : BTFN) -> Self::Output { -// g.apply(&self.x) * self.α -// } -// } -// } - - -impl From<(D, F)> for DeltaMeasure -where D : Into { - #[inline] - fn from((x, α) : (D, F)) -> Self { - DeltaMeasure{x: x.into(), α: α} - } -} - -impl<'a, Domain : Clone, F : Num> From<&'a DeltaMeasure> for DeltaMeasure { - #[inline] - fn from(d : &'a DeltaMeasure) -> Self { - d.clone() - } -} - - -impl DeltaMeasure { - /// Set the mass of the spike. - #[inline] - pub fn set_mass(&mut self, α : F) { - self.α = α - } - - /// Set the location of the spike. - #[inline] - pub fn set_location(&mut self, x : Domain) { - self.x = x - } - - /// Get the mass of the spike. - #[inline] - pub fn get_mass(&self) -> F { - self.α - } - - /// Get a mutable reference to the mass of the spike. - #[inline] - pub fn get_mass_mut(&mut self) -> &mut F { - &mut self.α - } - - /// Get a reference to the location of the spike. - #[inline] - pub fn get_location(&self) -> &Domain { - &self.x - } - - /// Get a mutable reference to the location of the spike. 
- #[inline] - pub fn get_location_mut(&mut self) -> &mut Domain { - &mut self.x - } -} - -impl IntoIterator for DeltaMeasure { - type Item = Self; - type IntoIter = std::iter::Once; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - std::iter::once(self) - } -} - -impl<'a, Domain, F : Num> IntoIterator for &'a DeltaMeasure { - type Item = Self; - type IntoIter = std::iter::Once; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - std::iter::once(self) - } -} - - -macro_rules! make_delta_scalarop_rhs { - ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl $trait for DeltaMeasure { - type Output = Self; - fn $fn(mut self, b : F) -> Self { - self.α.$fn_assign(b); - self - } - } - - impl<'a, F : Num, Domain> $trait<&'a F> for DeltaMeasure { - type Output = Self; - fn $fn(mut self, b : &'a F) -> Self { - self.α.$fn_assign(*b); - self - } - } - - impl<'b, F : Num, Domain : Clone> $trait for &'b DeltaMeasure { - type Output = DeltaMeasure; - fn $fn(self, b : F) -> Self::Output { - DeltaMeasure { α : self.α.$fn(b), x : self.x.clone() } - } - } - - impl<'a, 'b, F : Num, Domain : Clone> $trait<&'a F> for &'b DeltaMeasure { - type Output = DeltaMeasure; - fn $fn(self, b : &'a F) -> Self::Output { - DeltaMeasure { α : self.α.$fn(*b), x : self.x.clone() } - } - } - - impl $trait_assign for DeltaMeasure { - fn $fn_assign(&mut self, b : F) { - self.α.$fn_assign(b) - } - } - - impl<'a, F : Num, Domain> $trait_assign<&'a F> for DeltaMeasure { - fn $fn_assign(&mut self, b : &'a F) { - self.α.$fn_assign(*b) - } - } - } -} - -make_delta_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); -make_delta_scalarop_rhs!(Div, div, DivAssign, div_assign); - -macro_rules! 
make_delta_scalarop_lhs { - ($trait:ident, $fn:ident; $($f:ident)+) => { $( - impl $trait> for $f { - type Output = DeltaMeasure; - fn $fn(self, mut δ : DeltaMeasure) -> Self::Output { - δ.α = self.$fn(δ.α); - δ - } - } - - impl<'a, Domain : Clone> $trait<&'a DeltaMeasure> for $f { - type Output = DeltaMeasure; - fn $fn(self, δ : &'a DeltaMeasure) -> Self::Output { - DeltaMeasure{ x : δ.x.clone(), α : self.$fn(δ.α) } - } - } - - impl<'b, Domain> $trait> for &'b $f { - type Output = DeltaMeasure; - fn $fn(self, mut δ : DeltaMeasure) -> Self::Output { - δ.α = self.$fn(δ.α); - δ - } - } - - impl<'a, 'b, Domain : Clone> $trait<&'a DeltaMeasure> for &'b $f { - type Output = DeltaMeasure; - fn $fn(self, δ : &'a DeltaMeasure) -> Self::Output { - DeltaMeasure{ x : δ.x.clone(), α : self.$fn(δ.α) } - } - } - )+ } -} - -make_delta_scalarop_lhs!(Mul, mul; f32 f64 i8 i16 i32 i64 isize u8 u16 u32 u64 usize); -make_delta_scalarop_lhs!(Div, div; f32 f64 i8 i16 i32 i64 isize u8 u16 u32 u64 usize); - -macro_rules! make_delta_unary { - ($trait:ident, $fn:ident, $type:ty) => { - impl<'a, F : Num + Neg, Domain : Clone> Neg for $type { - type Output = DeltaMeasure; - fn $fn(self) -> Self::Output { - let mut tmp = self.clone(); - tmp.α = tmp.α.$fn(); - tmp - } - } - } -} - -make_delta_unary!(Neg, neg, DeltaMeasure); -make_delta_unary!(Neg, neg, &'a DeltaMeasure); - diff -r 9738b51d90d7 -r 4f468d35fa29 src/measures/discrete.rs --- a/src/measures/discrete.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,1014 +0,0 @@ -//! This module implementes discrete measures. 
- -use std::ops::{ - Div,Mul,DivAssign,MulAssign,Neg, - Add,Sub,AddAssign,SubAssign, - Index,IndexMut, -}; -use std::iter::Sum; -use serde::ser::{Serializer, Serialize, SerializeSeq}; -use nalgebra::DVector; - -use alg_tools::norms::Norm; -use alg_tools::tabledump::TableDump; -use alg_tools::linops::{Mapping, Linear}; -use alg_tools::iter::{MapF,Mappable}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::collection::Collection; -use alg_tools::instance::{Instance, Decomposition, MyCow, EitherDecomp, Space}; - -use crate::types::*; -use super::base::*; -use super::delta::*; - -/// Representation of a discrete measure. -/// -/// This is the measure $μ = ∑_{k=1}^n α_k δ_{x_k}$, consisting of several -/// [`DeltaMeasure`], i.e., “spikes” $α_k δ_{x_k}$ with weights $\alpha_k$ in `F` at locations -/// $x_k$ in `Domain`. -#[derive(Clone,Debug)] -pub struct DiscreteMeasure { - pub(super) spikes : Vec>, -} - -pub type RNDM = DiscreteMeasure, F>; - -/// Iterator over the [`DeltaMeasure`] spikes of a [`DiscreteMeasure`]. -pub type SpikeIter<'a, Domain, F> = std::slice::Iter<'a, DeltaMeasure>; - -/// Iterator over mutable [`DeltaMeasure`] spikes of a [`DiscreteMeasure`]. -pub type SpikeIterMut<'a, Domain, F> = std::slice::IterMut<'a, DeltaMeasure>; - -/// Iterator over the locations of the spikes of a [`DiscreteMeasure`]. -pub type LocationIter<'a, Domain, F> - = std::iter::Map, fn(&'a DeltaMeasure) -> &'a Domain>; - -/// Iterator over the masses of the spikes of a [`DiscreteMeasure`]. -pub type MassIter<'a, Domain, F> - = std::iter::Map, fn(&'a DeltaMeasure) -> F>; - -/// Iterator over the mutable locations of the spikes of a [`DiscreteMeasure`]. -pub type MassIterMut<'a, Domain, F> - = std::iter::Map, for<'r> fn(&'r mut DeltaMeasure) -> &'r mut F>; - -impl DiscreteMeasure { - /// Create a new zero measure (empty spike set). 
- pub fn new() -> Self { - DiscreteMeasure{ spikes : Vec::new() } - } - - /// Number of [`DeltaMeasure`] spikes in the measure - #[inline] - pub fn len(&self) -> usize { - self.spikes.len() - } - - /// Replace with the zero measure. - #[inline] - pub fn clear(&mut self) { - self.spikes.clear() - } - - /// Remove `i`:th spike, not maintaining order. - /// - /// Panics if indiex is out of bounds. - #[inline] - pub fn swap_remove(&mut self, i : usize) -> DeltaMeasure{ - self.spikes.swap_remove(i) - } - - /// Iterate over (references to) the [`DeltaMeasure`] spikes in this measure - #[inline] - pub fn iter_spikes(&self) -> SpikeIter<'_, Domain, F> { - self.spikes.iter() - } - - /// Iterate over mutable references to the [`DeltaMeasure`] spikes in this measure - #[inline] - pub fn iter_spikes_mut(&mut self) -> SpikeIterMut<'_, Domain, F> { - self.spikes.iter_mut() - } - - /// Iterate over the location of the spikes in this measure - #[inline] - pub fn iter_locations(&self) -> LocationIter<'_, Domain, F> { - self.iter_spikes().map(DeltaMeasure::get_location) - } - - /// Iterate over the masses of the spikes in this measure - #[inline] - pub fn iter_masses(&self) -> MassIter<'_, Domain, F> { - self.iter_spikes().map(DeltaMeasure::get_mass) - } - - /// Iterate over the masses of the spikes in this measure - #[inline] - pub fn iter_masses_mut(&mut self) -> MassIterMut<'_, Domain, F> { - self.iter_spikes_mut().map(DeltaMeasure::get_mass_mut) - } - - /// Update the masses of all the spikes to those produced by an iterator. - #[inline] - pub fn set_masses>(&mut self, iter : I) { - self.spikes.iter_mut().zip(iter).for_each(|(δ, α)| δ.set_mass(α)); - } - - /// Update the locations of all the spikes to those produced by an iterator. 
- #[inline] - pub fn set_locations<'a, I : Iterator>(&mut self, iter : I) - where Domain : 'static + Clone { - self.spikes.iter_mut().zip(iter.cloned()).for_each(|(δ, α)| δ.set_location(α)); - } - - // /// Map the masses of all the spikes using a function and an iterator - // #[inline] - // pub fn zipmap_masses< - // I : Iterator, - // G : Fn(F, I::Item) -> F - // > (&mut self, iter : I, g : G) { - // self.spikes.iter_mut().zip(iter).for_each(|(δ, v)| δ.set_mass(g(δ.get_mass(), v))); - // } - - /// Prune all spikes with zero mass. - #[inline] - pub fn prune(&mut self) { - self.prune_by(|δ| δ.α != F::ZERO); - } - - /// Prune spikes by the predicate `g`. - #[inline] - pub fn prune_by) -> bool>(&mut self, g : G) { - self.spikes.retain(g); - } - - /// Add the spikes produced by `iter` to this measure. - #[inline] - pub fn extend>>( - &mut self, - iter : I - ) { - self.spikes.extend(iter); - } - - /// Add a spike to the measure - #[inline] - pub fn push(&mut self, δ : DeltaMeasure) { - self.spikes.push(δ); - } - - /// Iterate over triples of masses and locations of two discrete measures, which are assumed - /// to have equal locations of same spike indices. 
- pub fn both_matching<'a>(&'a self, other : &'a DiscreteMeasure) -> - impl Iterator { - let m = self.len().max(other.len()); - self.iter_spikes().map(Some).chain(std::iter::repeat(None)) - .zip(other.iter_spikes().map(Some).chain(std::iter::repeat(None))) - .take(m) - .map(|(oδ, orδ)| { - match (oδ, orδ) { - (Some(δ), Some(rδ)) => (δ.α, rδ.α, &δ.x), // Assumed δ.x=rδ.x - (Some(δ), None) => (δ.α, F::ZERO, &δ.x), - (None, Some(rδ)) => (F::ZERO, rδ.α, &rδ.x), - (None, None) => panic!("This cannot happen!"), - } - }) - } - - /// Subtract `other` from `self`, assuming equal locations of same spike indices - pub fn sub_matching(&self, other : &DiscreteMeasure) -> DiscreteMeasure - where Domain : Clone { - self.both_matching(other) - .map(|(α, β, x)| (x.clone(), α - β)) - .collect() - } - - /// Add `other` to `self`, assuming equal locations of same spike indices - pub fn add_matching(&self, other : &DiscreteMeasure) -> DiscreteMeasure - where Domain : Clone { - self.both_matching(other) - .map(|(α, β, x)| (x.clone(), α + β)) - .collect() - } - - /// Calculate the Radon-norm distance of `self` to `other`, - /// assuming equal locations of same spike indices. 
- pub fn dist_matching(&self, other : &DiscreteMeasure) -> F where F : Float { - self.both_matching(other) - .map(|(α, β, _)| (α-β).abs()) - .sum() - } -} - -impl IntoIterator for DiscreteMeasure { - type Item = DeltaMeasure; - type IntoIter = std::vec::IntoIter>; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - self.spikes.into_iter() - } -} - -impl<'a, Domain, F : Num> IntoIterator for &'a DiscreteMeasure { - type Item = &'a DeltaMeasure; - type IntoIter = SpikeIter<'a, Domain, F>; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - self.spikes.iter() - } -} - -impl Sum> for DiscreteMeasure { - // Required method - fn sum(iter: I) -> Self - where - I : Iterator> - { - Self::from_iter(iter) - } -} - -impl<'a, Domain : Clone, F : Num> Sum<&'a DeltaMeasure> - for DiscreteMeasure -{ - // Required method - fn sum(iter: I) -> Self - where - I : Iterator> - { - Self::from_iter(iter.cloned()) - } -} - -impl Sum> for DiscreteMeasure { - // Required method - fn sum(iter: I) -> Self - where - I : Iterator> - { - Self::from_iter(iter.map(|μ| μ.into_iter()).flatten()) - } -} - -impl<'a, Domain : Clone, F : Num> Sum<&'a DiscreteMeasure> - for DiscreteMeasure -{ - // Required method - fn sum(iter: I) -> Self - where - I : Iterator> - { - Self::from_iter(iter.map(|μ| μ.iter_spikes()).flatten().cloned()) - } -} - -impl DiscreteMeasure { - /// Computes `μ1 ← θ * μ1 - ζ * μ2`, pruning entries where both `μ1` (`self`) and `μ2` have - // zero weight. `μ2` will contain a pruned copy of pruned original `μ1` without arithmetic - /// performed. **This expects `self` and `μ2` to have matching coordinates in each index**. - // `μ2` can be than `self`, but not longer. - pub fn pruning_sub(&mut self, θ : F, ζ : F, μ2 : &mut Self) { - for δ in &self[μ2.len()..] 
{ - μ2.push(DeltaMeasure{ x : δ.x.clone(), α : F::ZERO}); - } - debug_assert_eq!(self.len(), μ2.len()); - let mut dest = 0; - for i in 0..self.len() { - let α = self[i].α; - let α_new = θ * α - ζ * μ2[i].α; - if dest < i { - μ2[dest] = DeltaMeasure{ x : self[i].x.clone(), α }; - self[dest] = DeltaMeasure{ x : self[i].x.clone(), α : α_new }; - } else { - μ2[i].α = α; - self[i].α = α_new; - } - dest += 1; - } - self.spikes.truncate(dest); - μ2.spikes.truncate(dest); - } -} - -impl DiscreteMeasure { - /// Prune all spikes with mass absolute value less than the given `tolerance`. - #[inline] - pub fn prune_approx(&mut self, tolerance : F) { - self.spikes.retain(|δ| δ.α.abs() > tolerance); - } -} - -impl DiscreteMeasure { - /// Extracts the masses of the spikes as a [`DVector`]. - pub fn masses_dvector(&self) -> DVector { - DVector::from_iterator(self.len(), - self.iter_masses() - .map(|α| α.to_nalgebra_mixed())) - } - - /// Sets the masses of the spikes from the values of a [`DVector`]. - pub fn set_masses_dvector(&mut self, x : &DVector) { - self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); - } - - // /// Extracts the masses of the spikes as a [`Vec`]. - // pub fn masses_vec(&self) -> Vec { - // self.iter_masses() - // .map(|α| α.to_nalgebra_mixed()) - // .collect() - // } - - // /// Sets the masses of the spikes from the values of a [`Vec`]. 
- // pub fn set_masses_vec(&mut self, x : &Vec) { - // self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); - // } -} - -// impl Index for DiscreteMeasure { -// type Output = DeltaMeasure; -// #[inline] -// fn index(&self, i : usize) -> &Self::Output { -// self.spikes.index(i) -// } -// } - -// impl IndexMut for DiscreteMeasure { -// #[inline] -// fn index_mut(&mut self, i : usize) -> &mut Self::Output { -// self.spikes.index_mut(i) -// } -// } - -impl< - Domain, - F : Num, - I : std::slice::SliceIndex<[DeltaMeasure]> -> Index -for DiscreteMeasure { - type Output = ]>>::Output; - #[inline] - fn index(&self, i : I) -> &Self::Output { - self.spikes.index(i) - } -} - -impl< - Domain, - F : Num, - I : std::slice::SliceIndex<[DeltaMeasure]> -> IndexMut -for DiscreteMeasure { - #[inline] - fn index_mut(&mut self, i : I) -> &mut Self::Output { - self.spikes.index_mut(i) - } -} - - -impl>, const K : usize> From<[D; K]> -for DiscreteMeasure { - #[inline] - fn from(list : [D; K]) -> Self { - list.into_iter().collect() - } -} - -impl From>> -for DiscreteMeasure { - #[inline] - fn from(spikes : Vec>) -> Self { - DiscreteMeasure{ spikes } - } -} - -impl<'a, Domain, F : Num, D> From<&'a [D]> -for DiscreteMeasure -where &'a D : Into> { - #[inline] - fn from(list : &'a [D]) -> Self { - list.into_iter().map(|d| d.into()).collect() - } -} - - -impl From> -for DiscreteMeasure { - #[inline] - fn from(δ : DeltaMeasure) -> Self { - DiscreteMeasure{ - spikes : vec!(δ) - } - } -} - -impl<'a, Domain : Clone, F : Num> From<&'a DeltaMeasure> -for DiscreteMeasure { - #[inline] - fn from(δ : &'a DeltaMeasure) -> Self { - DiscreteMeasure{ - spikes : vec!(δ.clone()) - } - } -} - - -impl>> FromIterator -for DiscreteMeasure { - #[inline] - fn from_iter(iter : T) -> Self - where T : IntoIterator { - DiscreteMeasure{ - spikes : iter.into_iter().map(|m| m.into()).collect() - } - } -} - -impl<'a, F : Num, const N : usize> TableDump<'a> -for DiscreteMeasure,F> -where DeltaMeasure, F> : 
Serialize + 'a { - type Iter = std::slice::Iter<'a, DeltaMeasure, F>>; - - // fn tabledump_headers(&'a self) -> Vec { - // let mut v : Vec = (0..N).map(|i| format!("x{}", i)).collect(); - // v.push("weight".into()); - // v - // } - - fn tabledump_entries(&'a self) -> Self::Iter { - // Ensure order matching the headers above - self.spikes.iter() - } -} - -// Need to manually implement serialisation for DeltaMeasure, F> [`csv`] writer fails on -// structs with nested arrays as well as with #[serde(flatten)]. -// Then derive no longer works for DiscreteMeasure -impl Serialize for DiscreteMeasure, F> -where - F: Serialize, -{ - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut s = serializer.serialize_seq(Some(self.spikes.len()))?; - for δ in self.spikes.iter() { - s.serialize_element(δ)?; - } - s.end() - } -} - -impl Measure for DiscreteMeasure { - type Domain = Domain; -} - -impl Norm for DiscreteMeasure -where DeltaMeasure : Norm { - #[inline] - fn norm(&self, _ : Radon) -> F { - self.spikes.iter().map(|m| m.norm(Radon)).sum() - } -} - -impl Mapping for DiscreteMeasure -where - Domain : Space, - G::Codomain : Sum + Mul, - G : Mapping + Clone + Space, - for<'b> &'b Domain : Instance, -{ - type Codomain = G::Codomain; - - #[inline] - fn apply>(&self, g : I) -> Self::Codomain { - g.eval(|g| self.spikes.iter().map(|m| g.apply(&m.x) * m.α).sum()) - } -} - -impl Linear for DiscreteMeasure -where - Domain : Space, - G::Codomain : Sum + Mul, - G : Mapping + Clone + Space, - for<'b> &'b Domain : Instance, -{ } - - -/// Helper trait for constructing arithmetic operations for combinations -/// of [`DiscreteMeasure`] and [`DeltaMeasure`], and their references. -trait Lift { - type Producer : Iterator>; - - #[allow(dead_code)] - /// Lifts `self` into a [`DiscreteMeasure`]. 
- fn lift(self) -> DiscreteMeasure; - - /// Lifts `self` into a [`DiscreteMeasure`], apply either `f` or `f_mut` whether the type - /// this method is implemented for is a reference or or not. - fn lift_with(self, - f : impl Fn(&DeltaMeasure) -> DeltaMeasure, - f_mut : impl FnMut(&mut DeltaMeasure)) - -> DiscreteMeasure; - - /// Extend `self` into a [`DiscreteMeasure`] with the spikes produced by `iter`. - fn lift_extend>>( - self, - iter : I - ) -> DiscreteMeasure; - - /// Returns an iterator for producing copies of the spikes of `self`. - fn produce(self) -> Self::Producer; -} - -impl Lift for DiscreteMeasure { - type Producer = std::vec::IntoIter>; - - #[inline] - fn lift(self) -> DiscreteMeasure { self } - - fn lift_with(mut self, - _f : impl Fn(&DeltaMeasure) -> DeltaMeasure, - f_mut : impl FnMut(&mut DeltaMeasure)) - -> DiscreteMeasure { - self.spikes.iter_mut().for_each(f_mut); - self - } - - #[inline] - fn lift_extend>>( - mut self, - iter : I - ) -> DiscreteMeasure { - self.spikes.extend(iter); - self - } - - #[inline] - fn produce(self) -> Self::Producer { - self.spikes.into_iter() - } -} - -impl<'a, F : Num, Domain : Clone> Lift for &'a DiscreteMeasure { - type Producer = MapF>, DeltaMeasure>; - - #[inline] - fn lift(self) -> DiscreteMeasure { self.clone() } - - fn lift_with(self, - f : impl Fn(&DeltaMeasure) -> DeltaMeasure, - _f_mut : impl FnMut(&mut DeltaMeasure)) - -> DiscreteMeasure { - DiscreteMeasure{ spikes : self.spikes.iter().map(f).collect() } - } - - #[inline] - fn lift_extend>>( - self, - iter : I - ) -> DiscreteMeasure { - let mut res = self.clone(); - res.spikes.extend(iter); - res - } - - #[inline] - fn produce(self) -> Self::Producer { - // TODO: maybe not optimal to clone here and would benefit from - // a reference version of lift_extend. 
- self.spikes.iter().mapF(Clone::clone) - } -} - -impl Lift for DeltaMeasure { - type Producer = std::iter::Once>; - - #[inline] - fn lift(self) -> DiscreteMeasure { DiscreteMeasure { spikes : vec![self] } } - - #[inline] - fn lift_with(mut self, - _f : impl Fn(&DeltaMeasure) -> DeltaMeasure, - mut f_mut : impl FnMut(&mut DeltaMeasure)) - -> DiscreteMeasure { - f_mut(&mut self); - DiscreteMeasure{ spikes : vec![self] } - } - - #[inline] - fn lift_extend>>( - self, - iter : I - ) -> DiscreteMeasure { - let mut spikes = vec![self]; - spikes.extend(iter); - DiscreteMeasure{ spikes : spikes } - } - - #[inline] - fn produce(self) -> Self::Producer { - std::iter::once(self) - } -} - -impl<'a, F : Num, Domain : Clone> Lift for &'a DeltaMeasure { - type Producer = std::iter::Once>; - - #[inline] - fn lift(self) -> DiscreteMeasure { DiscreteMeasure { spikes : vec![self.clone()] } } - - #[inline] - fn lift_with(self, - f : impl Fn(&DeltaMeasure) -> DeltaMeasure, - _f_mut : impl FnMut(&mut DeltaMeasure)) - -> DiscreteMeasure { - DiscreteMeasure{ spikes : vec![f(self)] } - } - - #[inline] - fn lift_extend>>( - self, - iter : I - ) -> DiscreteMeasure { - let mut spikes = vec![self.clone()]; - spikes.extend(iter); - DiscreteMeasure{ spikes : spikes } - } - - #[inline] - fn produce(self) -> Self::Producer { - std::iter::once(self.clone()) - } -} - -macro_rules! 
make_discrete_addsub_assign { - ($rhs:ty) => { - // Discrete += (&)Discrete - impl<'a, F : Num, Domain : Clone> AddAssign<$rhs> - for DiscreteMeasure { - fn add_assign(&mut self, other : $rhs) { - self.spikes.extend(other.produce()); - } - } - - impl<'a, F : Num + Neg, Domain : Clone> SubAssign<$rhs> - for DiscreteMeasure { - fn sub_assign(&mut self, other : $rhs) { - self.spikes.extend(other.produce().map(|δ| -δ)); - } - } - } -} - -make_discrete_addsub_assign!(DiscreteMeasure); -make_discrete_addsub_assign!(&'a DiscreteMeasure); -make_discrete_addsub_assign!(DeltaMeasure); -make_discrete_addsub_assign!(&'a DeltaMeasure); - -macro_rules! make_discrete_addsub { - ($lhs:ty, $rhs:ty, $alt_order:expr) => { - impl<'a, 'b, F : Num, Domain : Clone> Add<$rhs> for $lhs { - type Output = DiscreteMeasure; - fn add(self, other : $rhs) -> DiscreteMeasure { - if !$alt_order { - self.lift_extend(other.produce()) - } else { - other.lift_extend(self.produce()) - } - } - } - - impl<'a, 'b, F : Num + Neg, Domain : Clone> Sub<$rhs> for $lhs { - type Output = DiscreteMeasure; - fn sub(self, other : $rhs) -> DiscreteMeasure { - self.lift_extend(other.produce().map(|δ| -δ)) - } - } - }; -} - -make_discrete_addsub!(DiscreteMeasure, DiscreteMeasure, false); -make_discrete_addsub!(DiscreteMeasure, &'b DiscreteMeasure, false); -make_discrete_addsub!(&'a DiscreteMeasure, DiscreteMeasure, true); -make_discrete_addsub!(&'a DiscreteMeasure, &'b DiscreteMeasure, false); -make_discrete_addsub!(DeltaMeasure, DiscreteMeasure, false); -make_discrete_addsub!(DeltaMeasure, &'b DiscreteMeasure, false); -make_discrete_addsub!(&'a DeltaMeasure, DiscreteMeasure, true); -make_discrete_addsub!(&'a DeltaMeasure, &'b DiscreteMeasure, false); -make_discrete_addsub!(DiscreteMeasure, DeltaMeasure, false); -make_discrete_addsub!(DiscreteMeasure, &'b DeltaMeasure, false); -make_discrete_addsub!(&'a DiscreteMeasure, DeltaMeasure, false); -make_discrete_addsub!(&'a DiscreteMeasure, &'b DeltaMeasure, false); 
-make_discrete_addsub!(DeltaMeasure, DeltaMeasure, false); -make_discrete_addsub!(DeltaMeasure, &'b DeltaMeasure, false); -make_discrete_addsub!(&'a DeltaMeasure, DeltaMeasure, false); -make_discrete_addsub!(&'a DeltaMeasure, &'b DeltaMeasure, false); - -macro_rules! make_discrete_scalarop_rhs { - ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - make_discrete_scalarop_rhs!(@assign DiscreteMeasure, F, $trait_assign, $fn_assign); - make_discrete_scalarop_rhs!(@assign DiscreteMeasure, &'a F, $trait_assign, $fn_assign); - make_discrete_scalarop_rhs!(@new DiscreteMeasure, F, $trait, $fn, $fn_assign); - make_discrete_scalarop_rhs!(@new DiscreteMeasure, &'a F, $trait, $fn, $fn_assign); - make_discrete_scalarop_rhs!(@new &'b DiscreteMeasure, F, $trait, $fn, $fn_assign); - make_discrete_scalarop_rhs!(@new &'b DiscreteMeasure, &'a F, $trait, $fn, $fn_assign); - }; - - (@assign $lhs:ty, $rhs:ty, $trait_assign:ident, $fn_assign:ident) => { - impl<'a, 'b, F : Num, Domain> $trait_assign<$rhs> for $lhs { - fn $fn_assign(&mut self, b : $rhs) { - self.spikes.iter_mut().for_each(|δ| δ.$fn_assign(b)); - } - } - }; - (@new $lhs:ty, $rhs:ty, $trait:ident, $fn:ident, $fn_assign:ident) => { - impl<'a, 'b, F : Num, Domain : Clone> $trait<$rhs> for $lhs { - type Output = DiscreteMeasure; - fn $fn(self, b : $rhs) -> Self::Output { - self.lift_with(|δ| δ.$fn(b), |δ| δ.$fn_assign(b)) - } - } - }; -} - -make_discrete_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); -make_discrete_scalarop_rhs!(Div, div, DivAssign, div_assign); - -macro_rules! 
make_discrete_unary { - ($trait:ident, $fn:ident, $type:ty) => { - impl<'a, F : Num + Neg, Domain : Clone> Neg for $type { - type Output = DiscreteMeasure; - fn $fn(self) -> Self::Output { - self.lift_with(|δ| δ.$fn(), |δ| δ.α = δ.α.$fn()) - } - } - } -} - -make_discrete_unary!(Neg, neg, DiscreteMeasure); -make_discrete_unary!(Neg, neg, &'a DiscreteMeasure); - -// impl Neg for DiscreteMeasure { -// type Output = Self; -// fn $fn(mut self, b : F) -> Self { -// self.lift().spikes.iter_mut().for_each(|δ| δ.neg(b)); -// self -// } -// } - -macro_rules! make_discrete_scalarop_lhs { - ($trait:ident, $fn:ident; $($f:ident)+) => { $( - impl $trait> for $f { - type Output = DiscreteMeasure; - fn $fn(self, mut v : DiscreteMeasure) -> Self::Output { - v.spikes.iter_mut().for_each(|δ| δ.α = self.$fn(δ.α)); - v - } - } - - impl<'a, Domain : Copy> $trait<&'a DiscreteMeasure> for $f { - type Output = DiscreteMeasure; - fn $fn(self, v : &'a DiscreteMeasure) -> Self::Output { - DiscreteMeasure{ - spikes : v.spikes.iter().map(|δ| self.$fn(δ)).collect() - } - } - } - - impl<'b, Domain> $trait> for &'b $f { - type Output = DiscreteMeasure; - fn $fn(self, mut v : DiscreteMeasure) -> Self::Output { - v.spikes.iter_mut().for_each(|δ| δ.α = self.$fn(δ.α)); - v - } - } - - impl<'a, 'b, Domain : Copy> $trait<&'a DiscreteMeasure> for &'b $f { - type Output = DiscreteMeasure; - fn $fn(self, v : &'a DiscreteMeasure) -> Self::Output { - DiscreteMeasure{ - spikes : v.spikes.iter().map(|δ| self.$fn(δ)).collect() - } - } - } - )+ } -} - -make_discrete_scalarop_lhs!(Mul, mul; f32 f64 i8 i16 i32 i64 isize u8 u16 u32 u64 usize); -make_discrete_scalarop_lhs!(Div, div; f32 f64 i8 i16 i32 i64 isize u8 u16 u32 u64 usize); - -impl Collection for DiscreteMeasure { - type Element = DeltaMeasure; - type RefsIter<'a> = std::slice::Iter<'a, Self::Element> where Self : 'a; - - #[inline] - fn iter_refs(&self) -> Self::RefsIter<'_> { - self.iter_spikes() - } -} - -impl Space for DiscreteMeasure { - type Decomp = 
MeasureDecomp; -} - -pub type SpikeSlice<'b, Domain, F> = &'b [DeltaMeasure]; - -pub type EitherSlice<'b, Domain, F> = EitherDecomp< - Vec>, - SpikeSlice<'b, Domain, F> ->; - -impl Decomposition> for MeasureDecomp { - type Decomposition<'b> = EitherSlice<'b, Domain, F> where DiscreteMeasure : 'b; - type Reference<'b> = SpikeSlice<'b, Domain, F> where DiscreteMeasure : 'b; - - /// Left the lightweight reference type into a full decomposition type. - fn lift<'b>(r : Self::Reference<'b>) -> Self::Decomposition<'b> { - EitherDecomp::Borrowed(r) - } -} - -impl Instance, MeasureDecomp> -for DiscreteMeasure -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Owned(self.spikes) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - self.spikes.as_slice() - } - - fn cow<'b>(self) -> MyCow<'b, DiscreteMeasure> where Self : 'b { - MyCow::Owned(self) - } - - fn own(self) -> DiscreteMeasure { - self - } -} - -impl<'a, F : Num, Domain : Clone> Instance, MeasureDecomp> -for &'a DiscreteMeasure -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Borrowed(self.spikes.as_slice()) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - self.spikes.as_slice() - } - - fn cow<'b>(self) -> MyCow<'b, DiscreteMeasure> where Self : 'b { - MyCow::Borrowed(self) - } - - fn own(self) -> DiscreteMeasure { - self.clone() - } -} - -impl<'a, F : Num, Domain : Clone> Instance, MeasureDecomp> -for EitherSlice<'a, Domain, F> -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - self - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - match self { - EitherDecomp::Owned(v) => v.as_slice(), - EitherDecomp::Borrowed(s) => s, - } - } - - fn own(self) -> DiscreteMeasure { - match self { - EitherDecomp::Owned(v) => v.into(), - EitherDecomp::Borrowed(s) => s.into(), - } - } -} - -impl<'a, F : Num, Domain : Clone> Instance, 
MeasureDecomp> -for &'a EitherSlice<'a, Domain, F> -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - match self { - EitherDecomp::Owned(v) => EitherDecomp::Borrowed(v.as_slice()), - EitherDecomp::Borrowed(s) => EitherDecomp::Borrowed(s), - } - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - match self { - EitherDecomp::Owned(v) => v.as_slice(), - EitherDecomp::Borrowed(s) => s, - } - } - - fn own(self) -> DiscreteMeasure { - match self { - EitherDecomp::Owned(v) => v.as_slice(), - EitherDecomp::Borrowed(s) => s - }.into() - } -} - -impl<'a, F : Num, Domain : Clone> Instance, MeasureDecomp> -for SpikeSlice<'a, Domain, F> -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Borrowed(self) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - self - } - - fn own(self) -> DiscreteMeasure { - self.into() - } -} - -impl<'a, F : Num, Domain : Clone> Instance, MeasureDecomp> -for &'a SpikeSlice<'a, Domain, F> -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Borrowed(*self) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - *self - } - - fn own(self) -> DiscreteMeasure { - (*self).into() - } -} - -impl Instance, MeasureDecomp> -for DeltaMeasure -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Owned(vec![self]) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - std::slice::from_ref(self) - } - - fn own(self) -> DiscreteMeasure { - self.into() - } -} - -impl<'a, F : Num, Domain : Clone> Instance, MeasureDecomp> -for &'a DeltaMeasure -{ - fn decompose<'b>(self) - -> >>::Decomposition<'b> - where Self : 'b, DiscreteMeasure : 'b { - EitherDecomp::Borrowed(std::slice::from_ref(self)) - } - - fn ref_instance(&self) - -> >>::Reference<'_> - { - std::slice::from_ref(*self) - } - - fn own(self) -> DiscreteMeasure { - 
self.into() - } -} diff -r 9738b51d90d7 -r 4f468d35fa29 src/measures/merging.rs --- a/src/measures/merging.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,333 +0,0 @@ -/*! -Spike merging heuristics for [`DiscreteMeasure`]s. - -This module primarily provides the [`SpikeMerging`] trait, and within it, -the [`SpikeMerging::merge_spikes`] method. The trait is implemented on -[`DiscreteMeasure, F>`]s in dimensions `N=1` and `N=2`. -*/ - -use numeric_literals::replace_float_literals; -use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -//use clap::builder::{PossibleValuesParser, PossibleValue}; -use alg_tools::nanleast::NaNLeast; - -use super::delta::*; -use super::discrete::*; -use crate::types::*; - -/// Spike merging heuristic selection -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[allow(dead_code)] -pub struct SpikeMergingMethod { - // Merging radius - pub(crate) radius: F, - // Enabled - pub(crate) enabled: bool, - // Interpolate merged points - pub(crate) interp: bool, -} - -#[replace_float_literals(F::cast_from(literal))] -impl Default for SpikeMergingMethod { - fn default() -> Self { - SpikeMergingMethod { - radius: 0.01, - enabled: false, - interp: true, - } - } -} - -/// Trait for dimension-dependent implementation of heuristic peak merging strategies. -pub trait SpikeMerging { - /// Attempt spike merging according to [`SpikeMerging`] method. - /// - /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure - /// `accept` if any merging is performed. The closure should accept as its only parameter a - /// new candidate measure (it will generally be internally mutated `self`, although this is - /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of - /// an arbitrary value. This method will return that value for the *last* accepted merge, or - /// [`None`] if no merge was accepted. 
- /// - /// This method is stable with respect to spike locations: on merge, the weights of existing - /// removed spikes is set to zero, new ones inserted at the end of the spike vector. - /// They merge may also be performed by increasing the weights of the existing spikes, - /// without inserting new spikes. - fn merge_spikes(&mut self, method: SpikeMergingMethod, accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - if method.enabled { - self.do_merge_spikes_radius(method.radius, method.interp, accept) - } else { - 0 - } - } - - /// Attempt to merge spikes based on a value and a fitness function. - /// - /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of - /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` - // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial - /// `self` is returned. also the number of merges is returned; - fn merge_spikes_fitness( - &mut self, - method: SpikeMergingMethod, - value: G, - fitness: H, - ) -> (V, usize) - where - G: Fn(&'_ Self) -> V, - H: Fn(&'_ V) -> O, - O: PartialOrd, - { - let mut res = value(self); - let initial_fitness = fitness(&res); - let count = self.merge_spikes(method, |μ| { - res = value(μ); - fitness(&res) <= initial_fitness - }); - (res, count) - } - - /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). - /// - /// This method implements [`SpikeMerging::merge_spikes`]. - fn do_merge_spikes_radius(&mut self, ρ: F, interp: bool, accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool; -} - -#[replace_float_literals(F::cast_from(literal))] -impl DiscreteMeasure, F> { - /// Attempts to merge spikes with indices `i` and `j`. - /// - /// This assumes that the weights of the two spikes have already been checked not to be zero. - /// - /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. 
- /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its - /// return value. - /// - /// Returns the index of `self.spikes` storing the new spike. - fn attempt_merge( - &mut self, - i: usize, - j: usize, - interp: bool, - accept: &mut G, - ) -> Option - where - G: FnMut(&'_ Self) -> bool, - { - let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i]; - let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j]; - - if interp { - // Merge inplace - self.spikes[i].α = 0.0; - self.spikes[j].α = 0.0; - let αia = αi.abs(); - let αja = αj.abs(); - self.spikes.push(DeltaMeasure { - α: αi + αj, - x: (xi * αia + xj * αja) / (αia + αja), - }); - if accept(self) { - Some(self.spikes.len() - 1) - } else { - // Merge not accepted, restore modification - self.spikes[i].α = αi; - self.spikes[j].α = αj; - self.spikes.pop(); - None - } - } else { - // Attempt merge inplace, first combination - self.spikes[i].α = αi + αj; - self.spikes[j].α = 0.0; - if accept(self) { - // Merge accepted - Some(i) - } else { - // Attempt merge inplace, second combination - self.spikes[i].α = 0.0; - self.spikes[j].α = αi + αj; - if accept(self) { - // Merge accepted - Some(j) - } else { - // Merge not accepted, restore modification - self.spikes[i].α = αi; - self.spikes[j].α = αj; - None - } - } - } - } -} - -/// Sorts a vector of indices into `slice` by `compare`. -/// -/// The closure `compare` operators on references to elements of `slice`. -/// Returns the sorted vector of indices into `slice`. 
-pub fn sort_indices_by(slice: &[V], mut compare: F) -> Vec -where - F: FnMut(&V, &V) -> Ordering, -{ - let mut indices = Vec::from_iter(0..slice.len()); - indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); - indices -} - -#[replace_float_literals(F::cast_from(literal))] -impl SpikeMerging for DiscreteMeasure, F> { - fn do_merge_spikes_radius(&mut self, ρ: F, interp: bool, mut accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - // Sort by coordinate into an indexing array. - let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { - let &Loc([x1]) = &δ1.x; - let &Loc([x2]) = &δ2.x; - // nan-ignoring ordering of floats - NaNLeast(x1).cmp(&NaNLeast(x2)) - }); - - // Initialise result - let mut count = 0; - - // Scan consecutive pairs and merge if close enough and accepted by `accept`. - if indices.len() == 0 { - return count; - } - for k in 0..(indices.len() - 1) { - let i = indices[k]; - let j = indices[k + 1]; - let &DeltaMeasure { - x: Loc([xi]), - α: αi, - } = &self.spikes[i]; - let &DeltaMeasure { - x: Loc([xj]), - α: αj, - } = &self.spikes[j]; - debug_assert!(xi <= xj); - // If close enough, attempt merging - if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { - if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { - // For this to work (the debug_assert! to not trigger above), the new - // coordinate produced by attempt_merge has to be at most xj. - indices[k + 1] = l; - count += 1 - } - } - } - - count - } -} - -/// Orders `δ1` and `δ1` according to the first coordinate. 
-fn compare_first_coordinate( - δ1: &DeltaMeasure, F>, - δ2: &DeltaMeasure, F>, -) -> Ordering { - let &Loc([x11, ..]) = &δ1.x; - let &Loc([x21, ..]) = &δ2.x; - // nan-ignoring ordering of floats - NaNLeast(x11).cmp(&NaNLeast(x21)) -} - -#[replace_float_literals(F::cast_from(literal))] -impl SpikeMerging for DiscreteMeasure, F> { - fn do_merge_spikes_radius(&mut self, ρ: F, interp: bool, mut accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - // Sort by first coordinate into an indexing array. - let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); - - // Initialise result - let mut count = 0; - let mut start_scan_2nd = 0; - - // Scan in order - if indices.len() == 0 { - return count; - } - for k in 0..indices.len() - 1 { - let i = indices[k]; - let &DeltaMeasure { - x: Loc([xi1, xi2]), - α: αi, - } = &self[i]; - - if αi == 0.0 { - // Nothin to be done if the weight is already zero - continue; - } - - let mut closest = None; - - // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` - // the smallest invalid merging index on the previous loop iteration, because a - // the _closest_ mergeable spike might have index less than `k` in `indices`, and a - // merge with it might have not been attempted with this spike if a different closer - // spike was discovered based on the second coordinate. - 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() { - if l == k { - // Do not attempt to merge a spike with itself - continue; - } - let j = indices[l]; - let &DeltaMeasure { - x: Loc([xj1, xj2]), - α: αj, - } = &self[j]; - - if xj1 < xi1 - ρ { - // Spike `j = indices[l]` has too low first coordinate. Update starting index - // for next iteration, and continue scanning. - start_scan_2nd = l; - continue 'scan_2nd; - } else if xj1 > xi1 + ρ { - // Break out: spike `j = indices[l]` has already too high first coordinate, no - // more close enough spikes can be found due to the sorting of `indices`. 
- break 'scan_2nd; - } - - // If also second coordinate is close enough, attempt merging if closer than - // previously discovered mergeable spikes. - let d2 = (xi2 - xj2).abs(); - if αj != 0.0 && d2 <= ρ { - let r1 = xi1 - xj1; - let d = (d2 * d2 + r1 * r1).sqrt(); - match closest { - None => closest = Some((l, j, d)), - Some((_, _, r)) if r > d => closest = Some((l, j, d)), - _ => {} - } - } - } - - // Attempt merging closest close-enough spike - if let Some((l, j, _)) = closest { - if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { - // If merge was succesfull, make new spike candidate for merging. - indices[l] = n; - count += 1; - let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); - // Re-sort relevant range of indices - if l < k { - indices[l..k].sort_by(|&i, &j| compare(i, j)); - } else { - indices[k + 1..=l].sort_by(|&i, &j| compare(i, j)); - } - } - } - } - - count - } -} diff -r 9738b51d90d7 -r 4f468d35fa29 src/pdps.rs --- a/src/pdps.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/pdps.rs Thu Feb 26 11:38:43 2026 -0500 @@ -38,50 +38,25 @@

*/ -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use nalgebra::DVector; -use clap::ValueEnum; - -use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Mapping; -use alg_tools::norms::{ - Linfinity, - Projection, -}; -use alg_tools::mapping::{RealMapping, Instance}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::linops::AXPY; - -use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM}; +use crate::fb::{postprocess, prune_with_stats}; +use crate::forward_model::ForwardModel; use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::fb::{ - postprocess, - prune_with_stats -}; -pub use crate::prox_penalty::{ - FBGenericConfig, - ProxPenalty -}; +use crate::measures::merging::SpikeMergingMethod; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD}; use crate::regularisation::RegTerm; -use crate::dataterm::{ - DataTerm, - L2Squared, - L1 -}; -use crate::measures::merging::SpikeMergingMethod; - +use crate::types::*; +use alg_tools::convex::{Conjugable, ConvexMapping, Prox}; +use alg_tools::error::DynResult; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::linops::{Mapping, AXPY}; +use alg_tools::mapping::{DataTerm, Instance}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use anyhow::ensure; +use clap::ValueEnum; +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; /// Acceleration #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] @@ -93,15 +68,18 @@ #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")] Partial, /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed - #[clap(name = "full", help = 
"Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] - Full + #[clap( + name = "full", + help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed" + )] + Full, } #[replace_float_literals(F::cast_from(literal))] impl Acceleration { /// PDPS parameter acceleration. Updates τ and σ and returns ω. /// This uses dual strong convexity, not primal. - fn accelerate(self, τ : &mut F, σ : &mut F, γ : F) -> F { + fn accelerate(self, τ: &mut F, σ: &mut F, γ: F) -> F { match self { Acceleration::None => 1.0, Acceleration::Partial => { @@ -109,13 +87,13 @@ *σ *= ω; *τ /= ω; ω - }, + } Acceleration::Full => { let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); *σ *= ω; *τ /= ω; ω - }, + } } } } @@ -123,91 +101,35 @@ /// Settings for [`pointsource_pdps_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct PDPSConfig { +pub struct PDPSConfig { /// Primal step length scaling. We must have `τ0 * σ0 < 1`. - pub τ0 : F, + pub τ0: F, /// Dual step length scaling. We must have `τ0 * σ0 < 1`. - pub σ0 : F, + pub σ0: F, /// Accelerate if available - pub acceleration : Acceleration, + pub acceleration: Acceleration, /// Generic parameters - pub generic : FBGenericConfig, + pub generic: InsertionConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for PDPSConfig { +impl Default for PDPSConfig { fn default() -> Self { let τ0 = 5.0; PDPSConfig { τ0, - σ0 : 0.99/τ0, - acceleration : Acceleration::Partial, - generic : FBGenericConfig { - merging : SpikeMergingMethod { enabled : true, ..Default::default() }, - .. 
Default::default() + σ0: 0.99 / τ0, + acceleration: Acceleration::Partial, + generic: InsertionConfig { + merging: SpikeMergingMethod { enabled: true, ..Default::default() }, + ..Default::default() }, } } } -/// Trait for data terms for the PDPS -#[replace_float_literals(F::cast_from(literal))] -pub trait PDPSDataTerm : DataTerm { - /// Calculate some subdifferential at `x` for the conjugate - fn some_subdifferential(&self, x : V) -> V; - - /// Factor of strong convexity of the conjugate - #[inline] - fn factor_of_strong_convexity(&self) -> F { - 0.0 - } - - /// Perform dual update - fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); -} - - -#[replace_float_literals(F::cast_from(literal))] -impl PDPSDataTerm -for L2Squared -where - F : Float, - V : Euclidean + AXPY, - for<'b> &'b V : Instance, -{ - fn some_subdifferential(&self, x : V) -> V { x } - - fn factor_of_strong_convexity(&self) -> F { - 1.0 - } - - #[inline] - fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { - y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); - } -} - -#[replace_float_literals(F::cast_from(literal))] -impl -PDPSDataTerm, N> -for L1 { - fn some_subdifferential(&self, mut x : DVector) -> DVector { - // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. - x.iter_mut() - .for_each(|v| if *v != F::ZERO { *v = *v/::abs(*v) }); - x - } - - #[inline] - fn dual_update(&self, y : &mut DVector, y_prev : &DVector, σ : F) { - y.axpy(1.0, y_prev, σ); - y.proj_ball_mut(1.0, Linfinity); - } -} - /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. /// -/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared. /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. /// The operator `op𝒟` is used for forming the proximal term. 
Typically it is a convolution @@ -218,42 +140,44 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - pdpsconfig : &PDPSConfig, - iterator : I, - mut plotter : SeqPlotter, - dataterm : D, -) -> RNDM +pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>( + f: &'a DataTerm, A, Phi>, + reg: &Reg, + prox_penalty: &P, + pdpsconfig: &PDPSConfig, + iterator: I, + mut plotter: Plot, + μ0 : Option>, +) -> DynResult> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType=F>, - A::PreadjointCodomain : RealMapping, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - PlotLookup : Plotting, - RNDM : SpikeMerging, - D : PDPSDataTerm, - Reg : RegTerm, - P : ProxPenalty, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + for<'b> &'b A::Observable: Instance, + A::Observable: AXPY, + RNDM: SpikeMerging, + Reg: RegTerm, F>, + Phi: Conjugable, + for<'b> Phi::Conjugate<'b>: Prox, + P: ProxPenalty, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD>, + Plot: Plotter>, { + // Check parameters + ensure!( + pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, + "Invalid step length parameters" + ); - // Check parameters - assert!(pdpsconfig.τ0 > 0.0 && - pdpsconfig.σ0 > 0.0 && - pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, - "Invalid step length parameters"); + let opA = f.operator(); + let b = f.data(); + let phistar = f.fidelity().conjugate(); // Set up parameters let config = &pdpsconfig.generic; - let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); + let l = prox_penalty.step_length_bound_pd(opA)?; let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; - let γ = dataterm.factor_of_strong_convexity(); + let γ = phistar.factor_of_strong_convexity(); // We multiply tolerance 
by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -261,38 +185,35 @@ let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut y = dataterm.some_subdifferential(-b); - let mut y_prev = y.clone(); - let full_stats = |μ : &RNDM, ε, stats| IterInfo { - value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), - n_spikes : μ.len(), + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut y = f.residual(&μ); + let full_stats = |μ: &RNDM, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), + n_spikes: μ.len(), ε, // postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(y * τ); + // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above. 
+ let mut τv = opA.preadjoint().apply(y.clone() * τ); // Save current base point let μ_base = μ.clone(); - + // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats - ); + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, + )?; // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += prox_penalty.merge_spikes_no_fitness( - &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, - ); + stats.merged += prox_penalty + .merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, None, τ, ε, config, ®); } stats.pruned += prune_with_stats(&mut μ); @@ -300,11 +221,13 @@ let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); // Do dual update - y = b.clone(); // y = b - opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b - opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b - dataterm.dual_update(&mut y, &y_prev, σ); - y_prev.copy_from(&y); + // y = y_prev + τb + y.axpy(τ, b, 1.0); + // y = y_prev - τ(A[(1+ω)μ^{k+1}]-b) + opA.gemv(&mut y, -τ * (1.0 + ω), &μ, 1.0); + // y = y_prev - τ(A[(1+ω)μ^{k+1} - ω μ^k]-b) + opA.gemv(&mut y, τ * ω, &μ_base, 1.0); + y = phistar.prox(τ, y); // Give statistics if requested let iter = state.iteration(); @@ -318,6 +241,5 @@ ε = tolerance.update(ε, iter); } - postprocess(μ, config, dataterm, opA, b) + postprocess(μ, config, |μ| f.apply(μ)) } - diff -r 9738b51d90d7 -r 4f468d35fa29 src/plot.rs --- a/src/plot.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/plot.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,37 +1,30 @@ //! 
Plotting helper utilities +use crate::measures::*; +use alg_tools::lingrid::LinGrid; +use alg_tools::loc::Loc; +use alg_tools::mapping::RealMapping; +use alg_tools::tabledump::write_csv; +use alg_tools::types::*; use numeric_literals::replace_float_literals; use serde::Serialize; -use alg_tools::types::*; -use alg_tools::lingrid::LinGrid; -use alg_tools::mapping::RealMapping; -use alg_tools::loc::Loc; -use alg_tools::tabledump::write_csv; -use crate::measures::*; /// Helper trait for implementing dimension-dependent plotting routines. -pub trait Plotting { +pub trait Plotting { /// Plot several mappings and a discrete measure into a file. - fn plot_into_file_spikes< - F : Float, - T1 : RealMapping, - T2 : RealMapping - > ( - g : Option<&T1>, - ω : Option<&T2>, - grid : LinGrid, - μ : &RNDM, - filename : String, + fn plot_into_file_spikes, T2: RealMapping>( + g: Option<&T1>, + ω: Option<&T2>, + grid: LinGrid, + μ: &RNDM, + filename: String, ); /// Plot a mapping into a file, sampling values on a given grid. 
- fn plot_into_file< - F : Float, - T1 : RealMapping, - > ( - g : &T1, - grid : LinGrid, - filename : String, + fn plot_into_file>( + g: &T1, + grid: LinGrid, + filename: String, ); } @@ -39,172 +32,181 @@ pub struct PlotLookup; #[derive(Serialize)] -struct CSVHelper1 { - x : F, - f : F, +struct CSVHelper1 { + x: F, + f: F, } #[derive(Serialize)] -struct CSVHelper1_2{ - x : F, - g : Option, - omega : Option +struct CSVHelper1_2 { + x: F, + g: Option, + omega: Option, } #[derive(Serialize)] -struct CSVSpike1 { - x : F, - alpha : F, +struct CSVSpike1 { + x: F, + alpha: F, } impl Plotting<1> for PlotLookup { - fn plot_into_file_spikes< - F : Float, - T1 : RealMapping, - T2 : RealMapping - > ( - g0 : Option<&T1>, - ω0 : Option<&T2>, - grid : LinGrid, - μ : &DiscreteMeasure, F>, - filename : String, + fn plot_into_file_spikes, T2: RealMapping<1, F>>( + g0: Option<&T1>, + ω0: Option<&T2>, + grid: LinGrid<1, F>, + μ: &DiscreteMeasure, F>, + filename: String, ) { - let data = grid.into_iter().map(|p@Loc([x]) : Loc| CSVHelper1_2 { - x, - g : g0.map(|g| g.apply(&p)), - omega : ω0.map(|ω| ω.apply(&p)) - }); + let data = grid + .into_iter() + .map(|p @ Loc([x]): Loc<1, F>| CSVHelper1_2 { + x, + g: g0.map(|g| g.apply(&p)), + omega: ω0.map(|ω| ω.apply(&p)), + }); let csv_f = format!("{}_functions.csv", filename); write_csv(data, csv_f).expect("CSV save error"); let spikes = μ.iter_spikes().map(|δ| { let Loc([x]) = δ.x; - CSVSpike1 { x, alpha : δ.α } + CSVSpike1 { x, alpha: δ.α } }); let csv_f = format!("{}_spikes.csv", filename); write_csv(spikes, csv_f).expect("CSV save error"); } - fn plot_into_file< - F : Float, - T1 : RealMapping, - > ( - g : &T1, - grid : LinGrid, - filename : String, + fn plot_into_file>( + g: &T1, + grid: LinGrid<1, F>, + filename: String, ) { - let data = grid.into_iter().map(|p@Loc([x]) : Loc| CSVHelper1 { - x, - f : g.apply(&p), - }); + let data = grid + .into_iter() + .map(|p @ Loc([x]): Loc<1, F>| CSVHelper1 { x, f: g.apply(&p) }); let csv_f = 
format!("{}.txt", filename); write_csv(data, csv_f).expect("CSV save error"); } - } #[derive(Serialize)] -struct CSVHelper2 { - x : F, - y : F, - f : F, +struct CSVHelper2 { + x: F, + y: F, + f: F, } #[derive(Serialize)] -struct CSVHelper2_2{ - x : F, - y : F, - g : Option, - omega : Option +struct CSVHelper2_2 { + x: F, + y: F, + g: Option, + omega: Option, } #[derive(Serialize)] -struct CSVSpike2 { - x : F, - y : F, - alpha : F, +struct CSVSpike2 { + x: F, + y: F, + alpha: F, } - impl Plotting<2> for PlotLookup { #[replace_float_literals(F::cast_from(literal))] - fn plot_into_file_spikes< - F : Float, - T1 : RealMapping, - T2 : RealMapping - > ( - g0 : Option<&T1>, - ω0 : Option<&T2>, - grid : LinGrid, - μ : &DiscreteMeasure, F>, - filename : String, + fn plot_into_file_spikes, T2: RealMapping<2, F>>( + g0: Option<&T1>, + ω0: Option<&T2>, + grid: LinGrid<2, F>, + μ: &DiscreteMeasure, F>, + filename: String, ) { - let data = grid.into_iter().map(|p@Loc([x, y]) : Loc| CSVHelper2_2 { - x, - y, - g : g0.map(|g| g.apply(&p)), - omega : ω0.map(|ω| ω.apply(&p)) - }); + let data = grid + .into_iter() + .map(|p @ Loc([x, y]): Loc<2, F>| CSVHelper2_2 { + x, + y, + g: g0.map(|g| g.apply(&p)), + omega: ω0.map(|ω| ω.apply(&p)), + }); let csv_f = format!("{}_functions.csv", filename); write_csv(data, csv_f).expect("CSV save error"); let spikes = μ.iter_spikes().map(|δ| { let Loc([x, y]) = δ.x; - CSVSpike2 { x, y, alpha : δ.α } + CSVSpike2 { x, y, alpha: δ.α } }); let csv_f = format!("{}_spikes.csv", filename); write_csv(spikes, csv_f).expect("CSV save error"); } - fn plot_into_file< - F : Float, - T1 : RealMapping, - > ( - g : &T1, - grid : LinGrid, - filename : String, + fn plot_into_file>( + g: &T1, + grid: LinGrid<2, F>, + filename: String, ) { - let data = grid.into_iter().map(|p@Loc([x, y]) : Loc| CSVHelper2 { - x, - y, - f : g.apply(&p), - }); + let data = grid + .into_iter() + .map(|p @ Loc([x, y]): Loc<2, F>| CSVHelper2 { + x, + y, + f: g.apply(&p), + }); let csv_f = 
format!("{}.txt", filename); write_csv(data, csv_f).expect("CSV save error"); } - } -/// A helper structure for plotting a sequence of images. -#[derive(Clone,Debug)] -pub struct SeqPlotter { - /// File name prefix - prefix : String, - /// Maximum number of plots to perform - max_plots : usize, - /// Sampling grid - grid : LinGrid, - /// Current plot count - plot_count : usize, +/// Trait for plotters +pub trait Plotter { + /// Plot the functions `g` and `ω` as well as the spikes of `μ`. + fn plot_spikes(&mut self, iter: usize, g: Option<&T1>, ω: Option<&T2>, μ: &M); +} + +/// A plotter that does nothing. +pub struct NoPlotter; + +impl Plotter for NoPlotter { + fn plot_spikes(&mut self, _iter: usize, _g: Option<&T1>, _ω: Option<&T2>, _μ: &M) {} } -impl SeqPlotter -where PlotLookup : Plotting { - /// Creates a new sequence plotter instance - pub fn new(prefix : String, max_plots : usize, grid : LinGrid) -> Self { - SeqPlotter { prefix, max_plots, grid, plot_count : 0 } - } +/// A basic plotter. +/// +/// This calls [`PlotLookup::plot_into_file_spikes`] with a sequentially numbered file name. +#[derive(Clone, Debug)] +pub struct SeqPlotter { + /// File name prefix + prefix: String, + /// Maximum number of plots to perform + max_plots: usize, + /// Sampling grid + grid: LinGrid, + /// Current plot count + plot_count: usize, +} - /// This calls [`PlotLookup::plot_into_file_spikes`] with a sequentially numbered file name. 
- pub fn plot_spikes( - &mut self, - iter : usize, - g : Option<&T1>, - ω : Option<&T2>, - μ : &RNDM, - ) where T1 : RealMapping, - T2 : RealMapping - { +impl SeqPlotter +where + PlotLookup: Plotting, +{ + /// Creates a new sequence plotter instance + pub fn new(prefix: String, max_plots: usize, grid: LinGrid) -> Self { + SeqPlotter { + prefix, + max_plots, + grid, + plot_count: 0, + } + } +} + +impl Plotter> for SeqPlotter +where + F: Float, + T1: RealMapping, + T2: RealMapping, + PlotLookup: Plotting, +{ + fn plot_spikes(&mut self, iter: usize, g: Option<&T1>, ω: Option<&T2>, μ: &RNDM) { if self.plot_count == 0 && self.max_plots > 0 { std::fs::create_dir_all(&self.prefix).expect("Unable to create plot directory"); } @@ -214,7 +216,7 @@ ω, self.grid, μ, - format!("{}out{:03}", self.prefix, iter) + format!("{}out{:03}", self.prefix, iter), ); self.plot_count += 1; } diff -r 9738b51d90d7 -r 4f468d35fa29 src/preadjoint_helper.rs --- a/src/preadjoint_helper.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/preadjoint_helper.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,36 +2,42 @@ Preadjoint construction helper */ -use std::marker::PhantomData; +use alg_tools::error::DynResult; +pub use alg_tools::linops::*; +use alg_tools::norms::{HasDualExponent, Norm}; use alg_tools::types::*; -pub use alg_tools::linops::*; -use alg_tools::norms::{Norm, HasDualExponent}; +use std::marker::PhantomData; /// Helper structure for constructing preadjoints of `S` where `S : Linear`. /// [`Linear`] needs to be implemented for each instance, but [`Adjointable`] /// and [`BoundedLinear`] have blanket implementations. 
-#[derive(Clone,Debug)] -pub struct PreadjointHelper<'a, S : 'a, X> { - pub forward_op : &'a S, - _domain : PhantomData +#[derive(Clone, Debug)] +pub struct PreadjointHelper<'a, S: 'a, X> { + pub forward_op: &'a S, + _domain: PhantomData, } -impl<'a, S : 'a, X> PreadjointHelper<'a, S, X> { - pub fn new(forward_op : &'a S) -> Self { - PreadjointHelper { forward_op, _domain: PhantomData } +impl<'a, S: 'a, X> PreadjointHelper<'a, S, X> { + pub fn new(forward_op: &'a S) -> Self { + PreadjointHelper { + forward_op, + _domain: PhantomData, + } } } -impl<'a, X, Ypre, S> Adjointable -for PreadjointHelper<'a, S, X> +impl<'a, X, Ypre, S> Adjointable for PreadjointHelper<'a, S, X> where - X : Space, - Ypre : Space, - Self : Linear, - S : Clone + Linear + X: Space, + Ypre: Space, + Self: Linear, + S: Clone + Linear, { type AdjointCodomain = S::Codomain; - type Adjoint<'b> = S where Self : 'b; + type Adjoint<'b> + = S + where + Self: 'b; fn adjoint(&self) -> Self::Adjoint<'_> { self.forward_op.clone() @@ -39,17 +45,18 @@ } impl<'a, F, X, Ypre, ExpXpre, ExpYpre, S> BoundedLinear -for PreadjointHelper<'a, S, X> + for PreadjointHelper<'a, S, X> where - ExpXpre : HasDualExponent, - ExpYpre : HasDualExponent, - F : Float, - X : Space + Norm, - Ypre : Space + Norm, - Self : Linear, - S : 'a + Clone + BoundedLinear + ExpXpre: HasDualExponent, + ExpYpre: HasDualExponent, + F: Float, + X: Space + Norm, + Ypre: Space + Norm, + Self: Linear, + S: 'a + Clone + BoundedLinear, { - fn opnorm_bound(&self, expy : ExpYpre, expx : ExpXpre) -> F { - self.forward_op.opnorm_bound(expx.dual_exponent(), expy.dual_exponent()) + fn opnorm_bound(&self, expy: ExpYpre, expx: ExpXpre) -> DynResult { + self.forward_op + .opnorm_bound(expx.dual_exponent(), expy.dual_exponent()) } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/prox_penalty.rs --- a/src/prox_penalty.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/prox_penalty.rs Thu Feb 26 11:38:43 2026 -0500 @@ -7,13 +7,15 @@ use serde::{Deserialize, Serialize}; use 
crate::measures::merging::SpikeMergingMethod; -use crate::measures::RNDM; +use crate::measures::DiscreteMeasure; use crate::regularisation::RegTerm; use crate::subproblem::InnerSettings; use crate::tolerance::Tolerance; use crate::types::{IterInfo, RefinementSettings}; +use alg_tools::error::DynResult; +use alg_tools::instance::Space; use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; -use alg_tools::mapping::RealMapping; +use alg_tools::mapping::Mapping; use alg_tools::nalgebra_support::ToNalgebraRealField; pub mod radon_squared; @@ -23,7 +25,7 @@ /// Settings for the solution of the stepwise optimality condition. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FBGenericConfig { +pub struct InsertionConfig { /// Tolerance for point insertion. pub tolerance: Tolerance, @@ -68,9 +70,9 @@ } #[replace_float_literals(F::cast_from(literal))] -impl Default for FBGenericConfig { +impl Default for InsertionConfig { fn default() -> Self { - FBGenericConfig { + InsertionConfig { tolerance: Default::default(), insertion_cutoff_factor: 1.0, refinement: Default::default(), @@ -88,7 +90,7 @@ } } -impl FBGenericConfig { +impl InsertionConfig { /// Check if merging should be attempted this iteration pub fn merge_now(&self, state: &AlgIteratorIteration) -> bool { self.merging.enabled && state.iteration() % self.merge_every == 0 @@ -96,20 +98,30 @@ /// Returns the final merging method pub fn final_merging_method(&self) -> SpikeMergingMethod { - SpikeMergingMethod { - enabled: self.final_merging, - ..self.merging - } + SpikeMergingMethod { enabled: self.final_merging, ..self.merging } } } +/// Available proximal terms +#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub enum ProxTerm { + /// Partial-to-wave operator 𝒟. 
+ Wave, + /// Radon-norm squared + RadonSquared, +} + /// Trait for proximal penalties -pub trait ProxPenalty +pub trait ProxPenalty where F: Float + ToNalgebraRealField, - Reg: RegTerm, + Reg: RegTerm, + Domain: Space + Clone, { - type ReturnMapping: RealMapping; + type ReturnMapping: Mapping; + + /// Returns the type of this proximality penalty + fn prox_type() -> ProxTerm; /// Insert new spikes into `μ` to approximately satisfy optimality conditions /// with the forward step term fixed to `τv`. @@ -120,21 +132,21 @@ /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same /// spike locations, while `ν_delta` may have different locations. /// - /// `τv` is mutable to allow [`alg_tools::bisection_tree::BTFN`] refinement. - /// Actual values of `τv` are not supposed to be mutated. + /// `τv` is mutable to allow [`alg_tools::bounds::MinMaxMapping`] optimisation to + /// refine data. Actual values of `τv` are not supposed to be mutated. fn insert_and_reweigh( &self, - μ: &mut RNDM, + μ: &mut DiscreteMeasure, τv: &mut PreadjointCodomain, - μ_base: &RNDM, - ν_delta: Option<&RNDM>, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, reg: &Reg, state: &AlgIteratorIteration, - stats: &mut IterInfo, - ) -> (Option, bool) + stats: &mut IterInfo, + ) -> DynResult<(Option, bool)> where I: AlgIterator; @@ -145,15 +157,15 @@ /// is set. fn merge_spikes( &self, - μ: &mut RNDM, + μ: &mut DiscreteMeasure, τv: &mut PreadjointCodomain, - μ_base: &RNDM, - ν_delta: Option<&RNDM>, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, reg: &Reg, - fitness: Option) -> F>, + fitness: Option) -> F>, ) -> usize; /// Merge spikes, if possible. 
@@ -162,13 +174,13 @@ #[inline] fn merge_spikes_no_fitness( &self, - μ: &mut RNDM, + μ: &mut DiscreteMeasure, τv: &mut PreadjointCodomain, - μ_base: &RNDM, - ν_delta: Option<&RNDM>, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, reg: &Reg, ) -> usize { /// This is a hack to create a `None` of same type as a `Some` @@ -186,7 +198,31 @@ ε, config, reg, - into_none(Some(|_: &RNDM| F::ZERO)), + into_none(Some(|_: &DiscreteMeasure| F::ZERO)), ) } } + +/// Trait to calculate step length bound by `Dat` when the proximal penalty is `Self`, +/// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$, and +/// `Dat` represents the function $f$, then this trait should calculate `L` such that +/// $\|f'(x)-f'(y)\| ≤ L\|x-y\|_*$, where the step length is supposed to satisfy $τ L ≤ 1$. +pub trait StepLengthBound { + /// Returns $L$. + fn step_length_bound(&self, f: &Dat) -> DynResult; +} + +/// A variant of [`StepLengthBound`] for step length parameters for [`Pair`]s of variables. +pub trait StepLengthBoundPair { + fn step_length_bound_pair(&self, f: &Dat) -> DynResult<(F, F)>; +} + +/// Trait to calculate step length bound by the operator `A` when the proximal penalty is `Self`, +/// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$, +/// then this trait should calculate `L` such that +/// $\|Ax\| ≤ L\|x\|_*$, where the primal-dual step lengths are supposed to satisfy $τσ L^2 ≤ 1$. +/// The domain needs to be specified here, because A can operate on various domains. +pub trait StepLengthBoundPD { + /// Returns $L$.
+ fn step_length_bound_pd(&self, f: &A) -> DynResult; +} diff -r 9738b51d90d7 -r 4f468d35fa29 src/prox_penalty/radon_squared.rs --- a/src/prox_penalty/radon_squared.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/prox_penalty/radon_squared.rs Thu Feb 26 11:38:43 2026 -0500 @@ -4,81 +4,66 @@ Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. */ -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; +use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; +use crate::dataterm::QuadraticDataTerm; +use crate::forward_model::ForwardModel; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; +use crate::regularisation::RegTerm; +use crate::types::*; +use alg_tools::bounds::MinMaxMapping; +use alg_tools::error::DynResult; +use alg_tools::instance::{Instance, Space}; +use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; +use alg_tools::linops::BoundedLinear; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{Norm, L2}; +use anyhow::ensure; use nalgebra::DVector; - -use alg_tools::iterate::{ - AlgIteratorIteration, - AlgIterator -}; -use alg_tools::norms::{L2, Norm}; -use alg_tools::linops::Mapping; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTSearch, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; - -use crate::types::*; -use crate::measures::{ - RNDM, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::SpikeMerging; -use crate::regularisation::RegTerm; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy -}; -use super::{ - FBGenericConfig, - ProxPenalty, -}; +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; /// Radon-norm squared proximal penalty -#[derive(Copy,Clone,Serialize,Deserialize)] +#[derive(Copy, Clone, Serialize, 
Deserialize)] pub struct RadonSquared; #[replace_float_literals(F::cast_from(literal))] -impl -ProxPenalty, Reg, N> for RadonSquared +impl ProxPenalty for RadonSquared where - F : Float + ToNalgebraRealField, - GA : SupportGenerator + Clone, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - Reg : RegTerm, - RNDM : SpikeMerging, + Domain: Space + Clone + PartialEq + 'static, + for<'a> &'a Domain: Instance, + F: Float + ToNalgebraRealField, + M: MinMaxMapping, + Reg: RegTerm, + DiscreteMeasure: SpikeMerging, { - type ReturnMapping = BTFN; + type ReturnMapping = M; + + fn prox_type() -> ProxTerm { + ProxTerm::RadonSquared + } fn insert_and_reweigh( &self, - μ : &mut RNDM, - τv : &mut BTFN, - μ_base : &RNDM, - ν_delta: Option<&RNDM>, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - _state : &AlgIteratorIteration, - stats : &mut IterInfo, - ) -> (Option, bool) + μ: &mut DiscreteMeasure, + τv: &mut M, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, + τ: F, + ε: F, + config: &InsertionConfig, + reg: &Reg, + _state: &AlgIteratorIteration, + stats: &mut IterInfo, + ) -> DynResult<(Option, bool)> where - I : AlgIterator + I: AlgIterator, { let mut y = μ_base.masses_dvector(); - assert!(μ_base.len() <= μ.len()); - + ensure!(μ_base.len() <= μ.len()); + 'i_and_w: for i in 0..=1 { // Optimise weights if μ.len() > 0 { @@ -86,11 +71,13 @@ // from the beginning of the iteration are all contained in the immutable c and g. // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional // problems have not yet been updated to sign change. 
- let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); + let g̃ = DVector::from_iterator( + μ.len(), + μ.iter_locations() + .map(|ζ| -F::to_nalgebra_mixed(τv.apply(ζ))), + ); let mut x = μ.masses_dvector(); - y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len()-y.len()))); + y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len() - y.len()))); assert_eq!(y.len(), x.len()); // Solve finite-dimensional subproblem. // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to @@ -101,51 +88,49 @@ μ.set_masses_dvector(&x); } - if i>0 { + if i > 0 { // Simple debugging test to see if more inserts would be needed. Doesn't seem so. //let n = μ.dist_matching(μ_base); //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); - break 'i_and_w + break 'i_and_w; } - + // Calculate ‖μ - μ_base‖_ℳ // TODO: This assumes that ν_delta has no common locations with μ-μ_base. let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon)); - + // Find a spike to insert, if needed. // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. 
match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { - None => { break 'i_and_w }, + None => break 'i_and_w, Some((ξ, _v_ξ, _in_bounds)) => { // Weight is found out by running the finite-dimensional optimisation algorithm // above - *μ += DeltaMeasure { x : ξ, α : 0.0 }; + *μ += DeltaMeasure { x: ξ, α: 0.0 }; stats.inserted += 1; } }; } - (None, true) + Ok((None, true)) } fn merge_spikes( &self, - μ : &mut RNDM, - τv : &mut BTFN, - μ_base : &RNDM, - ν_delta: Option<&RNDM>, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - fitness : Option) -> F>, - ) -> usize - { + μ: &mut DiscreteMeasure, + τv: &mut M, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, + τ: F, + ε: F, + config: &InsertionConfig, + reg: &Reg, + fitness: Option) -> F>, + ) -> usize { if config.fitness_merging { if let Some(f) = fitness { - return μ.merge_spikes_fitness(config.merging, f, |&v| v) - .1 + return μ.merge_spikes_fitness(config.merging, f, |&v| v).1; } } μ.merge_spikes(config.merging, |μ_candidate| { @@ -167,16 +152,27 @@ } } - -impl AdjointProductBoundedBy, RadonSquared> -for A +#[replace_float_literals(F::cast_from(literal))] +impl<'a, F, A, Domain> StepLengthBound> for RadonSquared where - F : Float, - A : ForwardModel, F> + F: Float + ToNalgebraRealField, + Domain: Space + Norm, + A: ForwardModel + BoundedLinear, { - type FloatType = F; - - fn adjoint_product_bound(&self, _ : &RadonSquared) -> Option { - self.opnorm_bound(Radon, L2).powi(2).into() + fn step_length_bound(&self, f: &QuadraticDataTerm) -> DynResult { + // TODO: direct squared calculation + Ok(f.operator().opnorm_bound(Radon, L2)?.powi(2)) } } + +#[replace_float_literals(F::cast_from(literal))] +impl<'a, F, A, Domain> StepLengthBoundPD> for RadonSquared +where + Domain: Space + Clone + PartialEq + 'static, + F: Float + ToNalgebraRealField, + A: BoundedLinear, Radon, L2, F>, +{ + fn step_length_bound_pd(&self, opA: &A) -> DynResult { + opA.opnorm_bound(Radon, L2) + } +} diff -r 
9738b51d90d7 -r 4f468d35fa29 src/prox_penalty/wave.rs --- a/src/prox_penalty/wave.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/prox_penalty/wave.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,84 +2,70 @@ Basic proximal penalty based on convolution operators $𝒟$. */ -use numeric_literals::replace_float_literals; -use nalgebra::DVector; -use colored::Colorize; - -use alg_tools::types::*; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, RealMapping}; +use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD}; +use crate::dataterm::QuadraticDataTerm; +use crate::forward_model::ForwardModel; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon}; +use crate::regularisation::RegTerm; +use crate::seminorms::DiscreteMeasureOp; +use crate::types::IterInfo; +use alg_tools::bounds::MinMaxMapping; +use alg_tools::error::DynResult; +use alg_tools::iterate::{AlgIterator, AlgIteratorIteration}; +use alg_tools::linops::BoundedLinear; +use alg_tools::mapping::{Instance, Mapping, Space}; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::Linfinity; -use alg_tools::iterate::{ - AlgIteratorIteration, - AlgIterator, -}; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTSearch, - SupportGenerator, - LocalAnalysis, - BothGenerators, -}; -use crate::measures::{ - RNDM, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMerging, -}; -use crate::seminorms::DiscreteMeasureOp; -use crate::types::{ - IterInfo, -}; -use crate::regularisation::RegTerm; -use super::{ProxPenalty, FBGenericConfig}; +use alg_tools::norms::{Linfinity, Norm, NormExponent, L2}; +use alg_tools::types::*; +use colored::Colorize; +use nalgebra::DVector; +use numeric_literals::replace_float_literals; #[replace_float_literals(F::cast_from(literal))] -impl -ProxPenalty, Reg, N> for 𝒟 +impl ProxPenalty for 𝒟 where - F : Float + ToNalgebraRealField, - GA : SupportGenerator + Clone, - 
BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - K : RealMapping + LocalAnalysis, N>, - Reg : RegTerm, - RNDM : SpikeMerging, + Domain: Space + Clone + PartialEq + 'static, + for<'a> &'a Domain: Instance, + F: Float + ToNalgebraRealField, + 𝒟: DiscreteMeasureOp, + 𝒟::Codomain: Mapping, + M: Mapping, + for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>, + O: MinMaxMapping, + Reg: RegTerm, + DiscreteMeasure: SpikeMerging, { - type ReturnMapping = BTFN, BTA, N>; + type ReturnMapping = O; + + fn prox_type() -> ProxTerm { + ProxTerm::Wave + } fn insert_and_reweigh( &self, - μ : &mut RNDM, - τv : &mut BTFN, - μ_base : &RNDM, - ν_delta: Option<&RNDM>, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - state : &AlgIteratorIteration, - stats : &mut IterInfo, - ) -> (Option, BTA, N>>, bool) + μ: &mut DiscreteMeasure, + τv: &mut M, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, + τ: F, + ε: F, + config: &InsertionConfig, + reg: &Reg, + state: &AlgIteratorIteration, + stats: &mut IterInfo, + ) -> DynResult<(Option, bool)> where - I : AlgIterator + I: AlgIterator, { - - let op𝒟norm = self.opnorm_bound(Radon, Linfinity); + let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?; // Maximum insertion count and measure difference calculation depend on insertion style. 
- let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { - (i, Some((l, k))) if i <= l => (k, false), - _ => (config.max_insertions, !state.is_quiet()), - }; + let (max_insertions, warn_insertions) = + match (state.iteration(), config.bootstrap_insertions) { + (i, Some((l, k))) if i <= l => (k, false), + _ => (config.max_insertions, !state.is_quiet()), + }; let ω0 = match ν_delta { None => self.apply(μ_base), @@ -95,10 +81,12 @@ // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional // problems have not yet been updated to sign change. let à = self.findim_matrix(μ.iter_locations()); - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) - .map(F::to_nalgebra_mixed)); + let g̃ = DVector::from_iterator( + μ.len(), + μ.iter_locations() + .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) + .map(F::to_nalgebra_mixed), + ); let mut x = μ.masses_dvector(); // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. @@ -117,10 +105,11 @@ // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality // conditions in the predual space, and finding new points for insertion, if necessary. - let mut d = &*τv + match ν_delta { - None => self.preapply(μ.sub_matching(μ_base)), - Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) - }; + let mut d = &*τv + + match ν_delta { + None => self.preapply(μ.sub_matching(μ_base)), + Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν), + }; // If no merging heuristic is used, let's be more conservative about spike insertion, // and skip it after first round. 
If merging is done, being more greedy about spike @@ -132,20 +121,19 @@ }; // Find a spike to insert, if needed - let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( - &mut d, τ, ε, skip_by_rough_check, config - ) { - None => break 'insertion (true, d), - Some(res) => res, - }; + let (ξ, _v_ξ, in_bounds) = + match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) { + None => break 'insertion (true, d), + Some(res) => res, + }; // Break if maximum insertion count reached if count >= max_insertions { - break 'insertion (in_bounds, d) + break 'insertion (in_bounds, d); } // No point in optimising the weight here; the finite-dimensional algorithm is fast. - *μ += DeltaMeasure { x : ξ, α : 0.0 }; + *μ += DeltaMeasure { x: ξ, α: 0.0 }; count += 1; stats.inserted += 1; }; @@ -153,39 +141,76 @@ if !within_tolerances && warn_insertions { // Complain (but continue) if we failed to get within tolerances // by inserting more points. - let err = format!("Maximum insertions reached without achieving \ - subproblem solution tolerance"); + let err = format!( + "Maximum insertions reached without achieving \ + subproblem solution tolerance" + ); println!("{}", err.red()); } - (Some(d), within_tolerances) + Ok((Some(d), within_tolerances)) } fn merge_spikes( &self, - μ : &mut RNDM, - τv : &mut BTFN, - μ_base : &RNDM, - ν_delta: Option<&RNDM>, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - fitness : Option) -> F>, - ) -> usize - { + μ: &mut DiscreteMeasure, + τv: &mut M, + μ_base: &DiscreteMeasure, + ν_delta: Option<&DiscreteMeasure>, + τ: F, + ε: F, + config: &InsertionConfig, + reg: &Reg, + fitness: Option) -> F>, + ) -> usize { if config.fitness_merging { if let Some(f) = fitness { - return μ.merge_spikes_fitness(config.merging, f, |&v| v) - .1 + return μ.merge_spikes_fitness(config.merging, f, |&v| v).1; } } μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &*τv + self.preapply(match ν_delta { - None => 
μ_candidate.sub_matching(μ_base), - Some(ν) => μ_candidate.sub_matching(μ_base) - ν, - }); + let mut d = &*τv + + self.preapply(match ν_delta { + None => μ_candidate.sub_matching(μ_base), + Some(ν) => μ_candidate.sub_matching(μ_base) - ν, + }); reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) }) } } + +#[replace_float_literals(F::cast_from(literal))] +impl<'a, F, A, 𝒟, Domain> StepLengthBound, A>> + for 𝒟 +where + Domain: Space + Clone + PartialEq + 'static, + F: Float + ToNalgebraRealField, + 𝒟: DiscreteMeasureOp, + A: ForwardModel, F> + + for<'b> BoundedLinear, &'b 𝒟, L2, F>, + DiscreteMeasure: for<'b> Norm<&'b 𝒟, F>, + for<'b> &'b 𝒟: NormExponent, +{ + fn step_length_bound( + &self, + f: &QuadraticDataTerm, A>, + ) -> DynResult { + // TODO: direct squared calculation + Ok(f.operator().opnorm_bound(self, L2)?.powi(2)) + } +} + +#[replace_float_literals(F::cast_from(literal))] +impl StepLengthBoundPD> for 𝒟 +where + Domain: Space + Clone + PartialEq + 'static, + F: Float + ToNalgebraRealField, + 𝒟: DiscreteMeasureOp, + A: for<'a> BoundedLinear, &'a 𝒟, L2, F>, + DiscreteMeasure: for<'a> Norm<&'a 𝒟, F>, + for<'b> &'b 𝒟: NormExponent, +{ + fn step_length_bound_pd(&self, opA: &A) -> DynResult { + opA.opnorm_bound(self, L2) + } +} diff -r 9738b51d90d7 -r 4f468d35fa29 src/rand_distr.rs --- a/src/rand_distr.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/rand_distr.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,35 +1,42 @@ //! Random distribution wrappers and implementations +use alg_tools::types::*; use numeric_literals::replace_float_literals; use rand::Rng; -use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; -use serde::{Serialize, Deserialize}; -use serde::ser::{Serializer, SerializeStruct}; -use alg_tools::types::*; +use rand_distr::{Distribution, Normal, NormalError, StandardNormal}; +use serde::ser::{SerializeStruct, Serializer}; +use serde::{Deserialize, Serialize}; /// Wrapper for [`Normal`] that can be serialized by serde. 
#[derive(Debug)] -pub struct SerializableNormal(Normal) -where StandardNormal : Distribution; +pub struct SerializableNormal(Normal) +where + StandardNormal: Distribution; -impl Distribution for SerializableNormal -where StandardNormal : Distribution { +impl Distribution for SerializableNormal +where + StandardNormal: Distribution, +{ fn sample(&self, rng: &mut R) -> T where - R : Rng + ?Sized - { self.0.sample(rng) } + R: Rng + ?Sized, + { + self.0.sample(rng) + } } -impl SerializableNormal -where StandardNormal : Distribution { - pub fn new(mean : T, std_dev : T) -> Result, NormalError> { +impl SerializableNormal +where + StandardNormal: Distribution, +{ + pub fn new(mean: T, std_dev: T) -> Result, NormalError> { Ok(SerializableNormal(Normal::new(mean, std_dev)?)) } } impl Serialize for SerializableNormal where - StandardNormal : Distribution, + StandardNormal: Distribution, F: Float + Serialize, { fn serialize(&self, serializer: S) -> Result @@ -48,11 +55,11 @@ /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding /// probabilities $\\{1-p, p/2, p/2\\}$. #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub struct SaltAndPepper{ +pub struct SaltAndPepper { /// The magnitude parameter $m$ - magnitude : T, + magnitude: T, /// The probability parameter $p$ - probability : T + probability: T, } /// Error for [`SaltAndPepper`]. @@ -66,15 +73,16 @@ impl std::fmt::Display for SaltAndPepperError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(match self { - SaltAndPepperError::InvalidProbability => - " The probability parameter is not in the range [0, 1].", + SaltAndPepperError::InvalidProbability => { + " The probability parameter is not in the range [0, 1]." 
+ } }) } } #[replace_float_literals(T::cast_from(literal))] -impl SaltAndPepper { - pub fn new(magnitude : T, probability : T) -> Result, SaltAndPepperError> { +impl SaltAndPepper { + pub fn new(magnitude: T, probability: T) -> Result, SaltAndPepperError> { if probability > 1.0 || probability < 0.0 { Err(SaltAndPepperError::InvalidProbability) } else { @@ -84,16 +92,16 @@ } #[replace_float_literals(T::cast_from(literal))] -impl Distribution for SaltAndPepper { +impl Distribution for SaltAndPepper { fn sample(&self, rng: &mut R) -> T where - R : Rng + ?Sized + R: Rng + ?Sized, { - let (p, sign) : (float, bool) = rng.gen(); + let (p, sign): (float, bool) = rng.random(); match (p < self.probability.as_(), sign) { - (false, _) => 0.0, - (true, true) => self.magnitude, - (true, false) => -self.magnitude, + (false, _) => 0.0, + (true, true) => self.magnitude, + (true, false) => -self.magnitude, } } } diff -r 9738b51d90d7 -r 4f468d35fa29 src/regularisation.rs --- a/src/regularisation.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/regularisation.rs Thu Feb 26 11:38:43 2026 -0500 @@ -4,12 +4,12 @@ #[allow(unused_imports)] // Used by documentation. use crate::fb::pointsource_fb_reg; -use crate::fb::FBGenericConfig; -use crate::measures::{DeltaMeasure, Radon, RNDM}; +use crate::fb::InsertionConfig; +use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM}; #[allow(unused_imports)] // Used by documentation. 
use crate::sliding_fb::pointsource_sliding_fb_reg; use crate::types::*; -use alg_tools::instance::Instance; +use alg_tools::instance::{Instance, Space}; use alg_tools::linops::Mapping; use alg_tools::loc::Loc; use alg_tools::norms::Norm; @@ -20,9 +20,7 @@ l1squared_nonneg::l1squared_nonneg, l1squared_unconstrained::l1squared_unconstrained, nonneg::quadratic_nonneg, unconstrained::quadratic_unconstrained, }; -use alg_tools::bisection_tree::{ - BTSearch, Bounded, Bounds, LocalAnalysis, P2Minimise, SupportGenerator, BTFN, -}; +use alg_tools::bounds::{Bounds, MinMaxMapping}; use alg_tools::iterate::AlgIteratorFactory; use alg_tools::nalgebra_support::ToNalgebraRealField; use nalgebra::{DMatrix, DVector}; @@ -34,7 +32,7 @@ /// /// The only member of the struct is the regularisation parameter α. #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub struct NonnegRadonRegTerm(pub F /* α */); +pub struct NonnegRadonRegTerm(pub F /* α */); impl<'a, F: Float> NonnegRadonRegTerm { /// Returns the regularisation parameter @@ -44,12 +42,12 @@ } } -impl<'a, F: Float, const N: usize> Mapping> for NonnegRadonRegTerm { +impl<'a, F: Float, const N: usize> Mapping> for NonnegRadonRegTerm { type Codomain = F; fn apply(&self, μ: I) -> F where - I: Instance>, + I: Instance>, { self.α() * μ.eval(|x| x.norm(Radon)) } @@ -59,7 +57,7 @@ /// /// The only member of the struct is the regularisation parameter α. 
#[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub struct RadonRegTerm(pub F /* α */); +pub struct RadonRegTerm(pub F /* α */); impl<'a, F: Float> RadonRegTerm { /// Returns the regularisation parameter @@ -69,12 +67,12 @@ } } -impl<'a, F: Float, const N: usize> Mapping> for RadonRegTerm { +impl<'a, F: Float, const N: usize> Mapping> for RadonRegTerm { type Codomain = F; fn apply(&self, μ: I) -> F where - I: Instance>, + I: Instance>, { self.α() * μ.eval(|x| x.norm(Radon)) } @@ -82,19 +80,19 @@ /// Regularisation term configuration #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub enum Regularisation { +pub enum Regularisation { /// $α \\|μ\\|\_{ℳ(Ω)}$ Radon(F), /// $α\\|μ\\|\_{ℳ(Ω)} + δ_{≥ 0}(μ)$ NonnegRadon(F), } -impl<'a, F: Float, const N: usize> Mapping> for Regularisation { +impl<'a, F: Float, const N: usize> Mapping> for Regularisation { type Codomain = F; fn apply(&self, μ: I) -> F where - I: Instance>, + I: Instance>, { match *self { Self::Radon(α) => RadonRegTerm(α).apply(μ), @@ -104,8 +102,10 @@ } /// Abstraction of regularisation terms. -pub trait RegTerm: - Mapping, Codomain = F> +pub trait RegTerm: Mapping, Codomain = F> +where + Domain: Space + Clone, + F: Float + ToNalgebraRealField, { /// Approximately solve the problem ///
$$ @@ -125,7 +125,7 @@ x: &mut DVector, mA_normest: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize; /// Approximately solve the problem @@ -142,7 +142,7 @@ τ: F, x: &mut DVector, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize; /// Find a point where `d` may violate the tolerance `ε`. @@ -154,18 +154,16 @@ /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check /// terminating early. Otherwise returns a possibly violating point, the value of `d` there, /// and a boolean indicating whether the found point is in bounds. - fn find_tolerance_violation( + fn find_tolerance_violation( &self, - d: &mut BTFN, + d: &mut M, τ: F, ε: F, skip_by_rough_check: bool, - config: &FBGenericConfig, - ) -> Option<(Loc, F, bool)> + config: &InsertionConfig, + ) -> Option<(Domain, F, bool)> where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, { self.find_tolerance_violation_slack(d, τ, ε, skip_by_rough_check, config, F::ZERO) } @@ -182,35 +180,31 @@ /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check /// terminating early. Otherwise returns a possibly violating point, the value of `d` there, /// and a boolean indicating whether the found point is in bounds. - fn find_tolerance_violation_slack( + fn find_tolerance_violation_slack( &self, - d: &mut BTFN, + d: &mut M, τ: F, ε: F, skip_by_rough_check: bool, - config: &FBGenericConfig, + config: &InsertionConfig, slack: F, - ) -> Option<(Loc, F, bool)> + ) -> Option<(Domain, F, bool)> where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>; + M: MinMaxMapping; /// Verify that `d` is in bounds `ε` for a merge candidate `μ` /// /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser. 
- fn verify_merge_candidate( + fn verify_merge_candidate( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &DiscreteMeasure, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> bool where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>; + M: MinMaxMapping; /// Verify that `d` is in bounds `ε` for a merge candidate `μ` /// @@ -220,19 +214,17 @@ /// same coordinates at same agreeing indices. /// /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser. - fn verify_merge_candidate_radonsq( + fn verify_merge_candidate_radonsq( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &DiscreteMeasure, τ: F, ε: F, - config: &FBGenericConfig, - radon_μ: &RNDM, + config: &InsertionConfig, + radon_μ: &DiscreteMeasure, ) -> bool where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>; + M: MinMaxMapping; /// TODO: document this fn target_bounds(&self, τ: F, ε: F) -> Option>; @@ -244,32 +236,33 @@ } /// Abstraction of regularisation terms for [`pointsource_sliding_fb_reg`]. -pub trait SlidingRegTerm: RegTerm { +pub trait SlidingRegTerm: RegTerm +where + Domain: Space + Clone, + F: Float + ToNalgebraRealField, +{ /// Calculate $τ[w(z) - w(y)]$ for some w in the subdifferential of the regularisation /// term, such that $-ε ≤ τw - d ≤ ε$. 
- fn goodness( + fn goodness( &self, - d: &mut BTFN, - μ: &RNDM, - y: &Loc, - z: &Loc, + d: &mut M, + μ: &DiscreteMeasure, + y: &Domain, + z: &Domain, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> F where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>; + M: MinMaxMapping; /// Convert bound on the regulariser to a bond on the Radon norm fn radon_norm_bound(&self, b: F) -> F; } #[replace_float_literals(F::cast_from(literal))] -impl RegTerm for NonnegRadonRegTerm -where - Cube: P2Minimise, F>, +impl RegTerm, F> + for NonnegRadonRegTerm { fn solve_findim( &self, @@ -279,7 +272,7 @@ x: &mut DVector, mA_normest: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize { let inner_tolerance = ε * config.inner.tolerance_mult; let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); @@ -293,7 +286,7 @@ τ: F, x: &mut DVector, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize { let inner_tolerance = ε * config.inner.tolerance_mult; let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); @@ -301,52 +294,54 @@ } #[inline] - fn find_tolerance_violation_slack( + fn find_tolerance_violation_slack( &self, - d: &mut BTFN, + d: &mut M, τ: F, ε: F, skip_by_rough_check: bool, - config: &FBGenericConfig, + config: &InsertionConfig, slack: F, - ) -> Option<(Loc, F, bool)> + ) -> Option<(Loc, F, bool)> where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let keep_above = -τα - slack - ε; let minimise_below = -τα - slack - ε * config.insertion_cutoff_factor; let refinement_tolerance = ε * config.refinement.tolerance_mult; + // println!( + // "keep_above: {keep_above}, rough lower bound: {}, tolerance: {ε}, slack: {slack}, τα: {τα}", + // d.bounds().lower() + // ); + // If preliminary check indicates that we are in bounds, and if 
it otherwise matches // the insertion strategy, skip insertion. if skip_by_rough_check && d.bounds().lower() >= keep_above { None } else { // If the rough check didn't indicate no insertion needed, find minimising point. - d.minimise_below( + let res = d.minimise_below( minimise_below, refinement_tolerance, config.refinement.max_steps, - ) - .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ >= keep_above)) + ); + + res.map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ >= keep_above)) } } - fn verify_merge_candidate( + fn verify_merge_candidate( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &RNDM, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> bool where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let refinement_tolerance = ε * config.refinement.tolerance_mult; @@ -367,19 +362,17 @@ )); } - fn verify_merge_candidate_radonsq( + fn verify_merge_candidate_radonsq( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &RNDM, τ: F, ε: F, - config: &FBGenericConfig, - radon_μ: &RNDM, + config: &InsertionConfig, + radon_μ: &RNDM, ) -> bool where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let refinement_tolerance = ε * config.refinement.tolerance_mult; @@ -426,24 +419,21 @@ } #[replace_float_literals(F::cast_from(literal))] -impl SlidingRegTerm for NonnegRadonRegTerm -where - Cube: P2Minimise, F>, +impl SlidingRegTerm, F> + for NonnegRadonRegTerm { - fn goodness( + fn goodness( &self, - d: &mut BTFN, - _μ: &RNDM, - y: &Loc, - z: &Loc, + d: &mut M, + _μ: &RNDM, + y: &Loc, + z: &Loc, τ: F, ε: F, - _config: &FBGenericConfig, + _config: &InsertionConfig, ) -> F where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let w = |x| 1.0.min((ε + d.apply(x)) / (τ * self.α())); w(z) - w(y) @@ -455,10 +445,7 
@@ } #[replace_float_literals(F::cast_from(literal))] -impl RegTerm for RadonRegTerm -where - Cube: P2Minimise, F>, -{ +impl RegTerm, F> for RadonRegTerm { fn solve_findim( &self, mA: &DMatrix, @@ -467,7 +454,7 @@ x: &mut DVector, mA_normest: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize { let inner_tolerance = ε * config.inner.tolerance_mult; let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); @@ -481,26 +468,24 @@ τ: F, x: &mut DVector, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> usize { let inner_tolerance = ε * config.inner.tolerance_mult; let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); l1squared_unconstrained(y, g, τ * self.α(), 1.0, x, &config.inner, inner_it) } - fn find_tolerance_violation_slack( + fn find_tolerance_violation_slack( &self, - d: &mut BTFN, + d: &mut M, τ: F, ε: F, skip_by_rough_check: bool, - config: &FBGenericConfig, + config: &InsertionConfig, slack: F, - ) -> Option<(Loc, F, bool)> + ) -> Option<(Loc, F, bool)> where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let keep_below = τα + slack + ε; @@ -541,18 +526,16 @@ } } - fn verify_merge_candidate( + fn verify_merge_candidate( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &RNDM, τ: F, ε: F, - config: &FBGenericConfig, + config: &InsertionConfig, ) -> bool where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let refinement_tolerance = ε * config.refinement.tolerance_mult; @@ -585,19 +568,17 @@ )); } - fn verify_merge_candidate_radonsq( + fn verify_merge_candidate_radonsq( &self, - d: &mut BTFN, - μ: &RNDM, + d: &mut M, + μ: &RNDM, τ: F, ε: F, - config: &FBGenericConfig, - radon_μ: &RNDM, + config: &InsertionConfig, + radon_μ: &RNDM, ) -> bool where - BT: BTSearch>, - G: 
SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let τα = τ * self.α(); let refinement_tolerance = ε * config.refinement.tolerance_mult; @@ -650,24 +631,21 @@ } #[replace_float_literals(F::cast_from(literal))] -impl SlidingRegTerm for RadonRegTerm -where - Cube: P2Minimise, F>, +impl SlidingRegTerm, F> + for RadonRegTerm { - fn goodness( + fn goodness( &self, - d: &mut BTFN, - _μ: &RNDM, - y: &Loc, - z: &Loc, + d: &mut M, + _μ: &RNDM, + y: &Loc, + z: &Loc, τ: F, ε: F, - _config: &FBGenericConfig, + _config: &InsertionConfig, ) -> F where - BT: BTSearch>, - G: SupportGenerator, - G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + M: MinMaxMapping, F>, { let α = self.α(); let w = |x| { diff -r 9738b51d90d7 -r 4f468d35fa29 src/run.rs --- a/src/run.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/run.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,133 +2,81 @@ This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. 
*/ -use numeric_literals::replace_float_literals; -use colored::Colorize; -use serde::{Serialize, Deserialize}; -use serde_json; -use nalgebra::base::DVector; -use std::hash::Hash; -use chrono::{DateTime, Utc}; -use cpu_time::ProcessTime; -use clap::ValueEnum; -use std::collections::HashMap; -use std::time::Instant; - -use rand::prelude::{ - StdRng, - SeedableRng -}; -use rand_distr::Distribution; - -use alg_tools::bisection_tree::*; -use alg_tools::iterate::{ - Timed, - AlgIteratorOptions, - Verbose, - AlgIteratorFactory, - LoggingIteratorFactory, - TimingIteratorFactory, - BasicAlgIteratorFactory, -}; -use alg_tools::logger::Logger; -use alg_tools::error::{ - DynError, - DynResult, -}; -use alg_tools::tabledump::TableDump; -use alg_tools::sets::Cube; -use alg_tools::mapping::{ - RealMapping, - DifferentiableMapping, - DifferentiableRealMapping, - Instance -}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::euclidean::Euclidean; -use alg_tools::lingrid::{lingrid, LinSpace}; -use alg_tools::sets::SetOrd; -use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; -use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; -use alg_tools::convex::Zero; -use alg_tools::maputil::map3; -use alg_tools::direct_product::Pair; - -use crate::kernels::*; -use crate::types::*; -use crate::measures::*; -use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; -use crate::forward_model::*; +use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig}; use crate::forward_model::sensor_grid::{ - SensorGrid, - SensorGridBT, //SensorGridBTFN, Sensor, + SensorGrid, + SensorGridBT, Spread, }; - -use crate::fb::{ - FBConfig, - FBGenericConfig, - pointsource_fb_reg, - pointsource_fista_reg, +use crate::forward_model::*; +use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig}; +use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW}; +use crate::kernels::*; +use 
crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; +use crate::measures::*; +use crate::pdps::{pointsource_pdps_reg, PDPSConfig}; +use crate::plot::*; +use crate::prox_penalty::{ + ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair, }; -use crate::sliding_fb::{ - SlidingFBConfig, - TransportConfig, - pointsource_sliding_fb_reg -}; +use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm}; +use crate::seminorms::*; +use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig}; use crate::sliding_pdps::{ - SlidingPDPSConfig, - pointsource_sliding_pdps_pair -}; -use crate::forward_pdps::{ - ForwardPDPSConfig, - pointsource_forward_pdps_pair + pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig, }; -use crate::pdps::{ - PDPSConfig, - pointsource_pdps_reg, -}; -use crate::frank_wolfe::{ - FWConfig, - FWVariant, - pointsource_fw_reg, - //WeightOptim, +use crate::subproblem::{InnerMethod, InnerSettings}; +use crate::tolerance::Tolerance; +use crate::types::*; +use crate::{AlgorithmOverrides, CommandLineArgs}; +use alg_tools::bisection_tree::*; +use alg_tools::bounds::{Bounded, MinMaxMapping}; +use alg_tools::convex::{Conjugable, Norm222, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::discrete_gradient::{ForwardNeumann, Grad}; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::{ClosedEuclidean, Euclidean}; +use alg_tools::iterate::{ + AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed, + TimingIteratorFactory, ValueIteratorFactory, Verbose, }; -use crate::subproblem::{InnerSettings, InnerMethod}; -use crate::seminorms::*; -use crate::plot::*; -use crate::{AlgorithmOverrides, CommandLineArgs}; -use crate::tolerance::Tolerance; -use crate::regularisation::{ - Regularisation, - RadonRegTerm, - NonnegRadonRegTerm -}; -use crate::dataterm::{ - L1, - 
L2Squared, +use alg_tools::lingrid::lingrid; +use alg_tools::linops::{IdOp, RowOp, AXPY}; +use alg_tools::logger::Logger; +use alg_tools::mapping::{ + DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping, }; -use crate::prox_penalty::{ - RadonSquared, - //ProxPenalty, -}; -use alg_tools::norms::{L2, NormExponent}; +use alg_tools::maputil::map3; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{NormExponent, L1, L2}; use alg_tools::operator_arithmetic::Weighted; +use alg_tools::sets::Cube; +use alg_tools::sets::SetOrd; +use alg_tools::tabledump::TableDump; use anyhow::anyhow; +use chrono::{DateTime, Utc}; +use clap::ValueEnum; +use colored::Colorize; +use cpu_time::ProcessTime; +use nalgebra::base::DVector; +use numeric_literals::replace_float_literals; +use rand::prelude::{SeedableRng, StdRng}; +use rand_distr::Distribution; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::hash::Hash; +use std::time::Instant; +use thiserror::Error; -/// Available proximal terms -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum ProxTerm { - /// Partial-to-wave operator 𝒟. - Wave, - /// Radon-norm squared - RadonSquared -} +//#[cfg(feature = "pyo3")] +//use pyo3::pyclass; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum AlgorithmConfig { +pub enum AlgorithmConfig { FB(FBConfig, ProxTerm), FISTA(FBConfig, ProxTerm), FW(FWConfig), @@ -138,91 +86,114 @@ SlidingPDPS(SlidingPDPSConfig, ProxTerm), } -fn unpack_tolerance(v : &Vec) -> Tolerance { +fn unpack_tolerance(v: &Vec) -> Tolerance { assert!(v.len() == 3); - Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } + Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] } } -impl AlgorithmConfig { +impl AlgorithmConfig { /// Override supported parameters based on the command line. 
- pub fn cli_override(self, cli : &AlgorithmOverrides) -> Self { - let override_merging = |g : SpikeMergingMethod| { - SpikeMergingMethod { - enabled : cli.merge.unwrap_or(g.enabled), - radius : cli.merge_radius.unwrap_or(g.radius), - interp : cli.merge_interp.unwrap_or(g.interp), - } + pub fn cli_override(self, cli: &AlgorithmOverrides) -> Self { + let override_merging = |g: SpikeMergingMethod| SpikeMergingMethod { + enabled: cli.merge.unwrap_or(g.enabled), + radius: cli.merge_radius.unwrap_or(g.radius), + interp: cli.merge_interp.unwrap_or(g.interp), }; - let override_fb_generic = |g : FBGenericConfig| { - FBGenericConfig { - bootstrap_insertions : cli.bootstrap_insertions - .as_ref() - .map_or(g.bootstrap_insertions, - |n| Some((n[0], n[1]))), - merge_every : cli.merge_every.unwrap_or(g.merge_every), - merging : override_merging(g.merging), - final_merging : cli.final_merging.unwrap_or(g.final_merging), - fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), - tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), - .. g - } + let override_fb_generic = |g: InsertionConfig| InsertionConfig { + bootstrap_insertions: cli + .bootstrap_insertions + .as_ref() + .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), + merge_every: cli.merge_every.unwrap_or(g.merge_every), + merging: override_merging(g.merging), + final_merging: cli.final_merging.unwrap_or(g.final_merging), + fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(g.tolerance), + ..g }; - let override_transport = |g : TransportConfig| { - TransportConfig { - θ0 : cli.theta0.unwrap_or(g.θ0), - tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), - adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), - .. 
g - } + let override_transport = |g: TransportConfig| TransportConfig { + θ0: cli.theta0.unwrap_or(g.θ0), + tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), + adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), + ..g }; use AlgorithmConfig::*; match self { - FB(fb, prox) => FB(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - FISTA(fb, prox) => FISTA(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - PDPS(pdps, prox) => PDPS(PDPSConfig { - τ0 : cli.tau0.unwrap_or(pdps.τ0), - σ0 : cli.sigma0.unwrap_or(pdps.σ0), - acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - generic : override_fb_generic(pdps.generic), - .. pdps - }, prox), + FB(fb, prox) => FB( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + FISTA(fb, prox) => FISTA( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + PDPS(pdps, prox) => PDPS( + PDPSConfig { + τ0: cli.tau0.unwrap_or(pdps.τ0), + σ0: cli.sigma0.unwrap_or(pdps.σ0), + acceleration: cli.acceleration.unwrap_or(pdps.acceleration), + generic: override_fb_generic(pdps.generic), + ..pdps + }, + prox, + ), FW(fw) => FW(FWConfig { - merging : override_merging(fw.merging), - tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), - .. fw + merging: override_merging(fw.merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(fw.tolerance), + ..fw }), - SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { - τ0 : cli.tau0.unwrap_or(sfb.τ0), - transport : override_transport(sfb.transport), - insertion : override_fb_generic(sfb.insertion), - .. 
sfb - }, prox), - SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { - τ0 : cli.tau0.unwrap_or(spdps.τ0), - σp0 : cli.sigmap0.unwrap_or(spdps.σp0), - σd0 : cli.sigma0.unwrap_or(spdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - transport : override_transport(spdps.transport), - insertion : override_fb_generic(spdps.insertion), - .. spdps - }, prox), - ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { - τ0 : cli.tau0.unwrap_or(fpdps.τ0), - σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), - σd0 : cli.sigma0.unwrap_or(fpdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - insertion : override_fb_generic(fpdps.insertion), - .. fpdps - }, prox), + SlidingFB(sfb, prox) => SlidingFB( + SlidingFBConfig { + τ0: cli.tau0.unwrap_or(sfb.τ0), + σp0: cli.sigmap0.unwrap_or(sfb.σp0), + transport: override_transport(sfb.transport), + insertion: override_fb_generic(sfb.insertion), + ..sfb + }, + prox, + ), + SlidingPDPS(spdps, prox) => SlidingPDPS( + SlidingPDPSConfig { + τ0: cli.tau0.unwrap_or(spdps.τ0), + σp0: cli.sigmap0.unwrap_or(spdps.σp0), + σd0: cli.sigma0.unwrap_or(spdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + transport: override_transport(spdps.transport), + insertion: override_fb_generic(spdps.insertion), + ..spdps + }, + prox, + ), + ForwardPDPS(fpdps, prox) => ForwardPDPS( + ForwardPDPSConfig { + τ0: cli.tau0.unwrap_or(fpdps.τ0), + σp0: cli.sigmap0.unwrap_or(fpdps.σp0), + σd0: cli.sigma0.unwrap_or(fpdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + insertion: override_fb_generic(fpdps.insertion), + ..fpdps + }, + prox, + ), } } } @@ -230,13 +201,14 @@ /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. 
#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Named { - pub name : String, + pub name: String, #[serde(flatten)] - pub data : Data, + pub data: Data, } /// Shorthand algorithm configurations, to be used with the command line parser #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] pub enum DefaultAlgorithm { /// The μFB forward-backward method #[clap(name = "fb")] @@ -264,7 +236,6 @@ ForwardPDPS, // Radon variants - /// The μFB forward-backward method with radon-norm squared proximal term #[clap(name = "radon_fb")] RadonFB, @@ -287,70 +258,70 @@ impl DefaultAlgorithm { /// Returns the algorithm configuration corresponding to the algorithm shorthand - pub fn default_config(&self) -> AlgorithmConfig { + pub fn default_config(&self) -> AlgorithmConfig { use DefaultAlgorithm::*; - let radon_insertion = FBGenericConfig { - merging : SpikeMergingMethod{ interp : false, .. Default::default() }, - inner : InnerSettings { - method : InnerMethod::PDPS, // SSN not implemented - .. Default::default() + let radon_insertion = InsertionConfig { + merging: SpikeMergingMethod { interp: false, ..Default::default() }, + inner: InnerSettings { + method: InnerMethod::PDPS, // SSN not implemented + ..Default::default() }, - .. Default::default() + ..Default::default() }; match *self { FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), - FWRelax => AlgorithmConfig::FW(FWConfig{ - variant : FWVariant::Relaxed, - .. 
Default::default() - }), + FWRelax => { + AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() }) + } PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), // Radon variants - RadonFB => AlgorithmConfig::FB( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonFISTA => AlgorithmConfig::FISTA( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonPDPS => AlgorithmConfig::PDPS( - PDPSConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + PDPSConfig { generic: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingFB => AlgorithmConfig::SlidingFB( - SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingFBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( - SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( - ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), } } /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand - pub fn get_named(&self) -> Named> { + pub fn 
get_named(&self) -> Named> { self.to_named(self.default_config()) } - pub fn to_named(self, alg : AlgorithmConfig) -> Named> { - let name = self.to_possible_value().unwrap().get_name().to_string(); - Named{ name , data : alg } + pub fn to_named(self, alg: AlgorithmConfig) -> Named> { + Named { name: self.name(), data: alg } + } + + pub fn name(self) -> String { + self.to_possible_value().unwrap().get_name().to_string() } } - // // Floats cannot be hashed directly, so just hash the debug formatting // // for use as file identifier. // impl Hash for AlgorithmConfig { @@ -379,347 +350,389 @@ } } -type DefaultBT = BT< - DynamicDepth, - F, - usize, - Bounds, - N ->; -type DefaultSeminormOp = ConvolutionOp, N>; -type DefaultSG = SensorGrid::< - F, - Sensor, - Spread, - DefaultBT, - N ->; +type DefaultBT = BT, N>; +type DefaultSeminormOp = ConvolutionOp, N>; +type DefaultSG = + SensorGrid, N>; /// This is a dirty workaround to rust-csv not supporting struct flattening etc. #[derive(Serialize)] struct CSVLog { - iter : usize, - cpu_time : f64, - value : F, - relative_value : F, + iter: usize, + cpu_time: f64, + value: F, + relative_value: F, //post_value : F, - n_spikes : usize, - inner_iters : usize, - merged : usize, - pruned : usize, - this_iters : usize, + n_spikes: usize, + inner_iters: usize, + merged: usize, + pruned: usize, + this_iters: usize, + epsilon: F, } /// Collected experiment statistics #[derive(Clone, Debug, Serialize)] -struct ExperimentStats { +struct ExperimentStats { /// Signal-to-noise ratio in decibels - ssnr : F, + ssnr: F, /// Proportion of noise in the signal as a number in $[0, 1]$. - noise_ratio : F, + noise_ratio: F, /// When the experiment was run (UTC) - when : DateTime, + when: DateTime, } #[replace_float_literals(F::cast_from(literal))] -impl ExperimentStats { +impl ExperimentStats { /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. 
- fn new>(signal : &E, noise : &E) -> Self { + fn new>(signal: &E, noise: &E) -> Self { let s = signal.norm2_squared(); let n = noise.norm2_squared(); let noise_ratio = (n / s).sqrt(); - let ssnr = 10.0 * (s / n).log10(); - ExperimentStats { - ssnr, - noise_ratio, - when : Utc::now(), - } + let ssnr = 10.0 * (s / n).log10(); + ExperimentStats { ssnr, noise_ratio, when: Utc::now() } } } /// Collected algorithm statistics #[derive(Clone, Debug, Serialize)] -struct AlgorithmStats { +struct AlgorithmStats { /// Overall CPU time spent - cpu_time : F, + cpu_time: F, /// Real time spent - elapsed : F + elapsed: F, } - /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input /// and outputs a [`DynError`]. -fn write_json(filename : String, data : &T) -> DynError { +fn write_json(filename: String, data: &T) -> DynError { serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; Ok(()) } - /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] -pub struct ExperimentV2 -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution, - S : Sensor, - P : Spread, - K : SimpleConvolutionKernel, +pub struct ExperimentV2 +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution, + S: Sensor, + P: Spread, + K: SimpleConvolutionKernel, { /// Domain $Ω$. - pub domain : Cube, + pub domain: Cube, /// Number of sensors along each dimension - pub sensor_count : [usize; N], + pub sensor_count: [usize; N], /// Noise distribution - pub noise_distr : NoiseDistr, + pub noise_distr: NoiseDistr, /// Seed for random noise generation (for repeatable experiments) - pub noise_seed : u64, + pub noise_seed: u64, /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. - pub sensor : S, + pub sensor: S, /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. - pub spread : P, + pub spread: P, /// Kernel $ρ$ of $𝒟$. 
- pub kernel : K, + pub kernel: K, /// True point sources - pub μ_hat : RNDM, + pub μ_hat: RNDM, /// Regularisation term and parameter - pub regularisation : Regularisation, + pub regularisation: Regularisation, /// For plotting : how wide should the kernels be plotted - pub kernel_plot_width : F, + pub kernel_plot_width: F, /// Data term - pub dataterm : DataTerm, + pub dataterm: DataTermType, /// A map of default configurations for algorithms - pub algorithm_overrides : HashMap>, + pub algorithm_overrides: HashMap>, /// Default merge radius - pub default_merge_radius : F, + pub default_merge_radius: F, } #[derive(Debug, Clone, Serialize)] -pub struct ExperimentBiased -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution, - S : Sensor, - P : Spread, - K : SimpleConvolutionKernel, - B : Mapping, Codomain = F> + Serialize + std::fmt::Debug, +pub struct ExperimentBiased +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution, + S: Sensor, + P: Spread, + K: SimpleConvolutionKernel, + B: Mapping, Codomain = F> + Serialize + std::fmt::Debug, { /// Basic setup - pub base : ExperimentV2, + pub base: ExperimentV2, /// Weight of TV term - pub λ : F, + pub λ: F, /// Bias function - pub bias : B, + pub bias: B, } /// Trait for runnable experiments -pub trait RunnableExperiment { +pub trait RunnableExperiment { /// Run all algorithms provided, or default algorithms if none provided, on the experiment. - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError; + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError; /// Return algorithm default config - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides; -} - -/// Helper function to print experiment start message and save setup. -/// Returns saving prefix. 
-fn start_experiment( - experiment : &Named, - cli : &CommandLineArgs, - stats : S, -) -> DynResult -where - E : Serialize + std::fmt::Debug, - S : Serialize, -{ - let Named { name : experiment_name, data } = experiment; + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides; - println!("{}\n{}", - format!("Performing experiment {}…", experiment_name).cyan(), - format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); - - // Set up output directory - let prefix = format!("{}/{}/", cli.outdir, experiment_name); - - // Save experiment configuration and statistics - let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); - std::fs::create_dir_all(&prefix)?; - write_json(mkname_e("experiment"), experiment)?; - write_json(mkname_e("config"), cli)?; - write_json(mkname_e("stats"), &stats)?; - - Ok(prefix) + /// Experiment name + fn name(&self) -> &str; } /// Error codes for running an algorithm on an experiment. -enum RunError { +#[derive(Error, Debug)] +pub enum RunError { /// Algorithm not implemented for this experiment + #[error("Algorithm not implemented for this experiment")] NotImplemented, } use RunError::*; -type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< +type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory< 'a, - Timed>, - TimingIteratorFactory>> + Timed>, + TimingIteratorFactory>>, >; -/// Helper function to run all algorithms on an experiment. 
-fn do_runall Deserialize<'b>, Z, const N : usize>( - experiment_name : &String, - prefix : &String, - cli : &CommandLineArgs, - algorithms : Vec>>, - plotgrid : LinSpace, [usize; N]>, - mut save_extra : impl FnMut(String, Z) -> DynError, - mut do_alg : impl FnMut( - &AlgorithmConfig, - DoRunAllIt, - SeqPlotter, - String, - ) -> Result<(RNDM, Z), RunError>, -) -> DynError -where - PlotLookup : Plotting, +pub trait RunnableExperimentExtras: + RunnableExperiment + Serialize + Sized { - let mut logs = Vec::new(); + /// Helper function to print experiment start message and save setup. + /// Returns saving prefix. + fn start(&self, cli: &CommandLineArgs) -> DynResult { + let experiment_name = self.name(); + let ser = serde_json::to_string(self); - let iterator_options = AlgIteratorOptions{ - max_iter : cli.max_iter, - verbose_iter : cli.verbose_iter - .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, - |n| Verbose::Every(n)), - quiet : cli.quiet, - }; + println!( + "{}\n{}", + format!("Performing experiment {}…", experiment_name).cyan(), + format!( + "Experiment settings: {}", + if let Ok(ref s) = ser { + s + } else { + "" + } + ) + .bright_black(), + ); - // Run the algorithm(s) - for named @ Named { name : alg_name, data : alg } in algorithms.iter() { - let this_prefix = format!("{}{}/", prefix, alg_name); + // Set up output directory + let prefix = format!("{}/{}/", cli.outdir, experiment_name); + + // Save experiment configuration and statistics + std::fs::create_dir_all(&prefix)?; + write_json(format!("{prefix}experiment.json"), self)?; + write_json(format!("{prefix}config.json"), cli)?; - // Create Logger and IteratorFactory - let mut logger = Logger::new(); - let iterator = iterator_options.instantiate() - .timed() - .into_log(&mut logger); + Ok(prefix) + } - let running = if !cli.quiet { - format!("{}\n{}\n{}\n", - format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), - format!("Iteration settings: {}", 
serde_json::to_string(&iterator_options)?).bright_black(), - format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) - } else { - "".to_string() - }; - // - // The following is for postprocessing, which has been disabled anyway. - // - // let reg : Box> = match regularisation { - // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), - // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), - // }; - //let findim_data = reg.prepare_optimise_weights(&opA, &b); - //let inner_config : InnerSettings = Default::default(); - //let inner_it = inner_config.iterator_options; + /// Helper function to run all algorithms on an experiment. + fn do_runall( + &self, + prefix: &String, + cli: &CommandLineArgs, + algorithms: Vec>>, + mut make_plotter: impl FnMut(String) -> Plot, + mut save_extra: impl FnMut(String, Z) -> DynError, + init: P, + mut do_alg: impl FnMut( + (&AlgorithmConfig, DoRunAllIt, Plot, P, String), + ) -> DynResult<(RNDM, Z)>, + ) -> DynError + where + F: for<'b> Deserialize<'b>, + PlotLookup: Plotting, + P: Clone, + { + let experiment_name = self.name(); - // Create plotter and directory if needed. - let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; - let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); - - let start = Instant::now(); - let start_cpu = ProcessTime::now(); + let mut logs = Vec::new(); - let (μ, z) = match do_alg(alg, iterator, plotter, running) { - Ok(μ) => μ, - Err(RunError::NotImplemented) => { - let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. 
\ - Skipping.").red(); - eprintln!("{}", msg); - continue - } + let iterator_options = AlgIteratorOptions { + max_iter: cli.max_iter, + verbose_iter: cli + .verbose_iter + .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| { + Verbose::Every(n) + }), + quiet: cli.quiet, }; - let elapsed = start.elapsed().as_secs_f64(); - let cpu_time = start_cpu.elapsed().as_secs_f64(); + // Run the algorithm(s) + for named @ Named { name: alg_name, data: alg } in algorithms.iter() { + let this_prefix = format!("{}{}/", prefix, alg_name); - println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); + // Create Logger and IteratorFactory + let mut logger = Logger::new(); + let iterator = iterator_options.instantiate().timed().into_log(&mut logger); - // Save results - println!("{}", "Saving results …".green()); + let running = if !cli.quiet { + format!( + "{}\n{}\n{}\n", + format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), + format!( + "Iteration settings: {}", + serde_json::to_string(&iterator_options)? + ) + .bright_black(), + format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black() + ) + } else { + "".to_string() + }; + // + // The following is for postprocessing, which has been disabled anyway. + // + // let reg : Box> = match regularisation { + // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), + // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), + // }; + //let findim_data = reg.prepare_optimise_weights(&opA, &b); + //let inner_config : InnerSettings = Default::default(); + //let inner_it = inner_config.iterator_options; - let mkname = |t| format!("{prefix}{alg_name}_{t}"); + // Create plotter and directory if needed. 
+ let plotter = make_plotter(this_prefix); + + let start = Instant::now(); + let start_cpu = ProcessTime::now(); - write_json(mkname("config.json"), &named)?; - write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; - μ.write_csv(mkname("reco.txt"))?; - save_extra(mkname(""), z)?; - //logger.write_csv(mkname("log.txt"))?; - logs.push((mkname("log.txt"), logger)); + let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) { + Ok(μ) => μ, + Err(e) => { + let msg = format!( + "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}" + ) + .red(); + eprintln!("{}", msg); + continue; + } + }; + + let elapsed = start.elapsed().as_secs_f64(); + let cpu_time = start_cpu.elapsed().as_secs_f64(); + + println!( + "{}", + format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow() + ); + + // Save results + println!("{}", "Saving results …".green()); + + let mkname = |t| format!("{prefix}{alg_name}_{t}"); + + write_json(mkname("config.json"), &named)?; + write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; + μ.write_csv(mkname("reco.txt"))?; + save_extra(mkname(""), z)?; + //logger.write_csv(mkname("log.txt"))?; + logs.push((mkname("log.txt"), logger)); + } + + save_logs( + logs, + format!("{prefix}valuerange.json"), + cli.load_valuerange, + ) } +} - save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) +impl RunnableExperimentExtras for E +where + F: ClapFloat, + Self: RunnableExperiment + Serialize, +{ } #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for -Named> +impl RunnableExperiment + for Named> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor + Copy + Serialize + std::fmt::Debug, - P : Spread + Copy + Serialize + std::fmt::Debug, - Convolution: Spread + Bounded + LocalAnalysis, N> + Copy - // TODO: shold not have differentiability as a requirement, but - 
// decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc, Output = Loc>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping - + Lipschitz, - for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution

: BoundedBy, - K : SimpleConvolutionKernel + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor + Copy + Serialize + std::fmt::Debug, + P: Spread + Copy + Serialize + std::fmt::Debug, + Convolution: Spread + + Bounded + LocalAnalysis, N> - + Copy + Serialize + std::fmt::Debug, - Cube: P2Minimise, F> + SetOrd, - PlotLookup : Plotting, - DefaultBT : SensorGridBT + BTSearch, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc, Output = Loc>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping + + Lipschitz, + for<'b> as DifferentiableMapping>>::Differential<'b>: + Lipschitz, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution

: BoundedBy, + K: SimpleConvolutionKernel + + LocalAnalysis, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube: P2Minimise, F> + SetOrd, + PlotLookup: Plotting, + DefaultBT: SensorGridBT + BTSearch, BTNodeLookup: BTNode, N>, - RNDM : SpikeMerging, - NoiseDistr : Distribution + Serialize + std::fmt::Debug, - // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, - // PreadjointCodomain : Space + Bounded + DifferentiableRealMapping, - // DefaultSeminormOp : ProxPenalty, N>, - // DefaultSeminormOp : ProxPenalty, N>, - // RadonSquared : ProxPenalty, N>, - // RadonSquared : ProxPenalty, N>, + RNDM: SpikeMerging, + NoiseDistr: Distribution + Serialize + std::fmt::Debug, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides { AlgorithmOverrides { - merge_radius : Some(self.data.default_merge_radius), - .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.default_merge_radius), + ..self + .data + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, - .. - } - } = self; + let &ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, + .. 
+ } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { (Some(algs), _) => algs, - (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], - (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], + (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()], + (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()], }; // Set up operators @@ -738,255 +751,348 @@ // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |_, ()| Ok(()); - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let μ = match alg { - AlgorithmConfig::FB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, 
ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::FISTA(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::SlidingFB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => 
Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::PDPS(ref algconfig, prox) => { - print!("{running}"); - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), 
&RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - // _ => Err(NotImplemented), - } - }, - AlgorithmConfig::FW(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, RadonRegTerm(α), - algconfig, iterator, plotter) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), - algconfig, iterator, plotter) - }), - _ => Err(NotImplemented), - } - }, - _ => Err(NotImplemented), - }?; - Ok((μ, ())) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L1, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L1, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, 
Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + } + } +} + +/// Runs PDPS if `alg` so requests and `prox_penalty` matches. +/// +/// Due to the structure of the PDPS, the data term `f` has to have a specific form. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>( + f: &'a DataTerm, A, Phi>, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + Phi: Conjugable, + for<'b> Phi::Conjugate<'b>: Prox, + for<'b> &'b A::Observable: Instance, + A::Observable: AXPY, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD>, + RNDM: SpikeMerging, + I: AlgIteratorFactory>, + Plot: Plotter>, +{ + match alg { + &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => { + print!("{running}"); + pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), } } +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. 
+pub fn run_fb( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F> + BoundedCurvature, + Dat::DerivativeDomain: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Plot: Plotter>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} + +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// For the moment, due to restrictions of the Frank–Wolfe implementation, only the +/// $L^2$-squared data term is enabled through the type signatures. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. 
+pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>( + f: &'a DataTerm, A, Norm222>, + reg: &Reg, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'b> &'b A::PreadjointCodomain: Instance, + Cube: P2Minimise, F>, + RNDM: SpikeMerging, + Reg: RegTermFW, N>, + Plot: Plotter>, +{ + match alg { + &AlgorithmConfig::FW(ref algconfig) => { + print!("{running}"); + pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for -Named> +impl RunnableExperiment + for Named> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor + Copy + Serialize + std::fmt::Debug, - P : Spread + Copy + Serialize + std::fmt::Debug, - Convolution: Spread + Bounded + LocalAnalysis, N> + Copy - // TODO: shold not have differentiability as a requirement, but - // decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc, Output = Loc>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping - + Lipschitz, - for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution

: BoundedBy, - K : SimpleConvolutionKernel + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor + Copy + Serialize + std::fmt::Debug, + P: Spread + Copy + Serialize + std::fmt::Debug, + Convolution: Spread + + Bounded + LocalAnalysis, N> - + Copy + Serialize + std::fmt::Debug, - Cube: P2Minimise, F> + SetOrd, - PlotLookup : Plotting, - DefaultBT : SensorGridBT + BTSearch, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc, Output = Loc>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping + + Lipschitz, + for<'b> as DifferentiableMapping>>::Differential<'b>: + Lipschitz, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution

: BoundedBy, + K: SimpleConvolutionKernel + + LocalAnalysis, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube: P2Minimise, F> + SetOrd, + PlotLookup: Plotting, + DefaultBT: SensorGridBT + BTSearch, BTNodeLookup: BTNode, N>, - RNDM : SpikeMerging, - NoiseDistr : Distribution + Serialize + std::fmt::Debug, - B : Mapping, Codomain = F> + Serialize + std::fmt::Debug, - // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, - // PreadjointCodomain : Bounded + DifferentiableRealMapping, + RNDM: SpikeMerging, + NoiseDistr: Distribution + Serialize + std::fmt::Debug, + B: Mapping, Codomain = F> + Serialize + std::fmt::Debug, + nalgebra::DVector: ClosedMul, + // This is mainly required for the final Mul requirement to be defined + // DefaultSG: ForwardModel< + // RNDM, + // F, + // PreadjointCodomain = PreadjointCodomain, + // Observable = DVector, + // >, + // PreadjointCodomain: Bounded + DifferentiableRealMapping + std::ops::Mul, + // Pair>: std::ops::Mul, // DefaultSeminormOp : ProxPenalty, N>, // DefaultSeminormOp : ProxPenalty, N>, // RadonSquared : ProxPenalty, N>, // RadonSquared : ProxPenalty, N>, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides { AlgorithmOverrides { - merge_radius : Some(self.data.base.default_merge_radius), - .. 
self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.base.default_merge_radius), + ..self + .data + .base + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentBiased { - λ, - ref bias, - base : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, + let &ExperimentBiased { + λ, + ref bias, + base: + ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, .. - } - } - } = self; + }, + } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { @@ -1000,173 +1106,304 @@ let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); let opAext = RowOp(opA.clone(), IdOp::new()); let fnR = Zero::new(); - let h = map3(domain.span_start(), domain.span_end(), sensor_count, - |a, b, n| (b-a)/F::cast_from(n)) - .into_iter() - .reduce(NumTraitsFloat::max) - .unwrap(); + let h = map3( + domain.span_start(), + domain.span_end(), + sensor_count, + |a, b, n| (b - a) / F::cast_from(n), + ) + .into_iter() + .reduce(NumTraitsFloat::max) + .unwrap(); let z = DVector::zeros(sensor_count.iter().product()); let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); let y = opKz.apply(&z); - let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} - // let zero_y = y.clone(); - // let zeroBTFN = opA.preadjoint().apply(&zero_y); - // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); + let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1} + // let zero_y = y.clone(); + // let 
zeroBTFN = opA.preadjoint().apply(&zero_y); + // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); // Set up random number generator. let mut rng = StdRng::seed_from_u64(noise_seed); // Generate the data and calculate SSNR statistic - let bias_vec = DVector::from_vec(opA.grid() - .into_iter() - .map(|v| bias.apply(v)) - .collect::>()); - let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; + let bias_vec = DVector::from_vec( + opA.grid() + .into_iter() + .map(|v| bias.apply(v)) + .collect::>(), + ); + let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec; let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); let b = &b_hat + &noise; // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; opA.write_observable(&bias_vec, format!("{prefix}bias"))?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); - // Run the algorithms - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let Pair(μ, z) = match alg { - AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - 
(Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, 
algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - _ => Err(NotImplemented) - }?; - Ok((μ, z)) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + _ => Err(NotImplemented.into()), + } + } +} + +type MeasureZ = Pair, Z>; + +pub fn run_pdps_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + opKz: &KOpZ, + fnR: &R, + fnH: &H, + (alg, iterator, plotter, μ0zy, running): ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z, Y), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z, Y), + String, + ), + ) -> DynResult, Z>>, +) -> DynResult, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + 
BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + //Pair: ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + KOpZ: BoundedLinear + + GEMV + + SimplyAdjointable, + KOpZ::SimpleAdjoint: GEMV, + Y: ClosedEuclidean + Clone, + for<'b> &'b Y: Instance, + Z: ClosedEuclidean + Clone + ClosedMul, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, + Plot: Plotter>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_forward_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + _ => cont((alg, iterator, plotter, μ0zy, running)), + } +} + +pub fn run_fb_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + fnR: &R, + (alg, iterator, plotter, μ0z, running): ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z), + String, + ), + ) -> DynResult, Z>>, +) -> DynResult, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + Z: ClosedEuclidean + AXPY + Clone, + for<'b> &'b Z: Instance, + R: Prox, + Plot: Plotter>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref 
algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0z, + fnR, + ) + } + _ => cont((alg, iterator, plotter, μ0z, running)), } } #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -struct ValueRange { - ini : F, - min : F, +struct ValueRange { + ini: F, + min: F, } -impl ValueRange { - fn expand_with(self, other : Self) -> Self { - ValueRange { - ini : self.ini.max(other.ini), - min : self.min.min(other.min), - } +impl ValueRange { + fn expand_with(self, other: Self) -> Self { + ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) } } } /// Calculative minimum and maximum values of all the `logs`, and save them into /// corresponding file names given as the first elements of the tuples in the vectors. 
-fn save_logs Deserialize<'b>, const N : usize>( - logs : Vec<(String, Logger>>)>, - valuerange_file : String, - load_valuerange : bool, +fn save_logs Deserialize<'b>>( + logs: Vec<(String, Logger>>)>, + valuerange_file: String, + load_valuerange: bool, ) -> DynError { // Process logs for relative values println!("{}", "Processing logs…"); // Find minimum value and initial value within a single log - let proc_single_log = |log : &Logger>>| { + let proc_single_log = |log: &Logger>>| { let d = log.data(); - let mi = d.iter() - .map(|i| i.data.value) - .reduce(NumTraitsFloat::min); + let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min); d.first() - .map(|i| i.data.value) - .zip(mi) - .map(|(ini, min)| ValueRange{ ini, min }) + .map(|i| i.data.value) + .zip(mi) + .map(|(ini, min)| ValueRange { ini, min }) }; // Find minimum and maximum value over all logs - let mut v = logs.iter() - .filter_map(|&(_, ref log)| proc_single_log(log)) - .reduce(|v1, v2| v1.expand_with(v2)) - .ok_or(anyhow!("No algorithms found"))?; + let mut v = logs + .iter() + .filter_map(|&(_, ref log)| proc_single_log(log)) + .reduce(|v1, v2| v1.expand_with(v2)) + .ok_or(anyhow!("No algorithms found"))?; // Load existing range if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { @@ -1183,10 +1420,11 @@ pruned, //postprocessing, this_iters, + ε, .. } = data; // let post_value = match (postprocessing, dataterm) { - // (Some(mut μ), DataTerm::L2Squared) => { + // (Some(mut μ), DataTermType::L222) => { // // Comparison postprocessing is only implemented for the case handled // // by the FW variants. 
// reg.optimise_weights( @@ -1198,18 +1436,19 @@ // }, // _ => value, // }; - let relative_value = (value - v.min)/(v.ini - v.min); + let relative_value = (value - v.min) / (v.ini - v.min); CSVLog { iter, value, relative_value, //post_value, n_spikes, - cpu_time : cpu_time.as_secs_f64(), + cpu_time: cpu_time.as_secs_f64(), inner_iters, merged, pruned, - this_iters + this_iters, + epsilon: ε, } }; @@ -1224,45 +1463,48 @@ Ok(()) } - /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] -fn plotall( - cli : &CommandLineArgs, - prefix : &String, - domain : &Cube, - sensor : &Sensor, - kernel : &Kernel, - spread : &Spread, - μ_hat : &RNDM, - op𝒟 : &𝒟, - opA : &A, - b_hat : &A::Observable, - b : &A::Observable, - kernel_plot_width : F, +fn plotall( + cli: &CommandLineArgs, + prefix: &String, + domain: &Cube, + sensor: &Sensor, + kernel: &Kernel, + spread: &Spread, + μ_hat: &RNDM, + op𝒟: &𝒟, + opA: &A, + b_hat: &A::Observable, + b: &A::Observable, + kernel_plot_width: F, ) -> DynError -where F : Float + ToNalgebraRealField, - Sensor : RealMapping + Support + Clone, - Spread : RealMapping + Support + Clone, - Kernel : RealMapping + Support, - Convolution : DifferentiableRealMapping + Support, - 𝒟 : DiscreteMeasureOp, F>, - 𝒟::Codomain : RealMapping, - A : ForwardModel, F>, - for<'a> &'a A::Observable : Instance, - A::PreadjointCodomain : DifferentiableRealMapping + Bounded, - PlotLookup : Plotting, - Cube : SetOrd { - +where + F: Float + ToNalgebraRealField, + Sensor: RealMapping + Support + Clone, + Spread: RealMapping + Support + Clone, + Kernel: RealMapping + Support, + Convolution: DifferentiableRealMapping + Support, + 𝒟: DiscreteMeasureOp, F>, + 𝒟::Codomain: RealMapping, + A: ForwardModel, F>, + for<'a> &'a A::Observable: Instance, + A::PreadjointCodomain: DifferentiableRealMapping + Bounded, + PlotLookup: Plotting, + Cube: SetOrd, +{ if cli.plot < PlotLevel::Data { - return Ok(()) + return Ok(()); } let base = Convolution(sensor.clone(), 
spread.clone()); - let resolution = if N==1 { 100 } else { 40 }; + let resolution = if N == 1 { 100 } else { 40 }; let pfx = |n| format!("{prefix}{n}"); - let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); + let plotgrid = lingrid( + &[[-kernel_plot_width, kernel_plot_width]; N].into(), + &[resolution; N], + ); PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); @@ -1272,19 +1514,19 @@ let plotgrid2 = lingrid(&domain, &[resolution; N]); let ω_hat = op𝒟.apply(μ_hat); - let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); + let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); - let preadj_b = opA.preadjoint().apply(b); - let preadj_b_hat = opA.preadjoint().apply(b_hat); + let preadj_b = opA.preadjoint().apply(b); + let preadj_b_hat = opA.preadjoint().apply(b_hat); //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); PlotLookup::plot_into_file_spikes( Some(&preadj_b), Some(&preadj_b_hat), plotgrid2, &μ_hat, - pfx("omega_b") + pfx("omega_b"), ); PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); diff -r 9738b51d90d7 -r 4f468d35fa29 src/seminorms.rs --- a/src/seminorms.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/seminorms.rs Thu Feb 26 11:38:43 2026 -0500 @@ -6,13 +6,15 @@ use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, SpikeIter, RNDM}; use alg_tools::bisection_tree::*; +use alg_tools::bounds::Bounded; +use alg_tools::error::DynResult; use alg_tools::instance::Instance; use alg_tools::iter::{FilterMapX, Mappable}; use alg_tools::linops::{BoundedLinear, Linear, Mapping}; use alg_tools::loc::Loc; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use 
alg_tools::norms::Linfinity; +use alg_tools::norms::{Linfinity, Norm, NormExponent}; use alg_tools::sets::Cube; use alg_tools::types::*; use itertools::Itertools; @@ -68,37 +70,37 @@ // /// A trait alias for simple convolution kernels. -pub trait SimpleConvolutionKernel: - RealMapping + Support + Bounded + Clone + 'static +pub trait SimpleConvolutionKernel: + RealMapping + Support + Bounded + Clone + 'static { } -impl SimpleConvolutionKernel for T where - T: RealMapping + Support + Bounded + Clone + 'static +impl SimpleConvolutionKernel for T where + T: RealMapping + Support + Bounded + Clone + 'static { } /// [`SupportGenerator`] for [`ConvolutionOp`]. #[derive(Clone, Debug)] -pub struct ConvolutionSupportGenerator +pub struct ConvolutionSupportGenerator where - K: SimpleConvolutionKernel, + K: SimpleConvolutionKernel, { kernel: K, - centres: RNDM, + centres: RNDM, } -impl ConvolutionSupportGenerator +impl ConvolutionSupportGenerator where - K: SimpleConvolutionKernel, + K: SimpleConvolutionKernel, { /// Construct the convolution kernel corresponding to `δ`, i.e., one centered at `δ.x` and /// weighted by `δ.α`. 
#[inline] fn construct_kernel<'a>( &'a self, - δ: &'a DeltaMeasure, F>, - ) -> Weighted, F> { + δ: &'a DeltaMeasure, F>, + ) -> Weighted, F> { self.kernel.clone().shift(δ.x).weigh(δ.α) } @@ -108,21 +110,21 @@ #[inline] fn construct_kernel_and_id_filtered<'a>( &'a self, - (id, δ): (usize, &'a DeltaMeasure, F>), - ) -> Option<(usize, Weighted, F>)> { + (id, δ): (usize, &'a DeltaMeasure, F>), + ) -> Option<(usize, Weighted, F>)> { (δ.α != F::ZERO).then(|| (id.into(), self.construct_kernel(δ))) } } -impl SupportGenerator for ConvolutionSupportGenerator +impl SupportGenerator for ConvolutionSupportGenerator where - K: SimpleConvolutionKernel, + K: SimpleConvolutionKernel, { type Id = usize; - type SupportType = Weighted, F>; + type SupportType = Weighted, F>; type AllDataIter<'a> = FilterMapX< 'a, - Zip, SpikeIter<'a, Loc, F>>, + Zip, SpikeIter<'a, Loc, F>>, Self, (Self::Id, Self::SupportType), >; @@ -150,13 +152,13 @@ pub struct ConvolutionOp where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, + BT: BTImpl, + K: SimpleConvolutionKernel, { /// Depth of the [`BT`] bisection tree for the outputs [`Mapping::apply`]. depth: BT::Depth, /// Domain of the [`BT`] bisection tree for the outputs [`Mapping::apply`]. - domain: Cube, + domain: Cube, /// The convolution kernel kernel: K, _phantoms: PhantomData<(F, BT)>, @@ -165,13 +167,13 @@ impl ConvolutionOp where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, + BT: BTImpl, + K: SimpleConvolutionKernel, { /// Creates a new convolution operator $𝒟$ with `kernel` on `domain`. /// /// The output of [`Mapping::apply`] is a [`BT`] of given `depth`. - pub fn new(depth: BT::Depth, domain: Cube, kernel: K) -> Self { + pub fn new(depth: BT::Depth, domain: Cube, kernel: K) -> Self { ConvolutionOp { depth: depth, domain: domain, @@ -181,7 +183,7 @@ } /// Returns the support generator for this convolution operator. 
- fn support_generator(&self, μ: RNDM) -> ConvolutionSupportGenerator { + fn support_generator(&self, μ: RNDM) -> ConvolutionSupportGenerator { // TODO: can we avoid cloning μ? ConvolutionSupportGenerator { kernel: self.kernel.clone(), @@ -195,18 +197,18 @@ } } -impl Mapping> for ConvolutionOp +impl Mapping> for ConvolutionOp where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, - Weighted, F>: LocalAnalysis, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, { - type Codomain = BTFN, BT, N>; + type Codomain = BTFN, BT, N>; fn apply(&self, μ: I) -> Self::Codomain where - I: Instance>, + I: Instance>, { let g = self.support_generator(μ.own()); BTFN::construct(self.domain.clone(), self.depth, g) @@ -214,46 +216,67 @@ } /// [`ConvolutionOp`]s as linear operators over [`DiscreteMeasure`]s. -impl Linear> for ConvolutionOp +impl Linear> for ConvolutionOp where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, - Weighted, F>: LocalAnalysis, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, { } -impl BoundedLinear, Radon, Linfinity, F> +impl BoundedLinear, Radon, Linfinity, F> for ConvolutionOp where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, - Weighted, F>: LocalAnalysis, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, { - fn opnorm_bound(&self, _: Radon, _: Linfinity) -> F { + fn opnorm_bound(&self, _: Radon, _: Linfinity) -> DynResult { // With μ = ∑_i α_i δ_{x_i}, we have // |𝒟μ|_∞ // = sup_z |∑_i α_i φ(z - x_i)| // ≤ sup_z ∑_i |α_i| |φ(z - x_i)| // ≤ ∑_i |α_i| |φ|_∞ // = |μ|_ℳ |φ|_∞ - self.kernel.bounds().uniform() + Ok(self.kernel.bounds().uniform()) } } -impl DiscreteMeasureOp, F> for ConvolutionOp +impl<'a, F, K, BT, const N: usize> NormExponent for &'a ConvolutionOp +where + F: Float + ToNalgebraRealField, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, +{ +} + +impl<'a, F, K, 
BT, const N: usize> Norm<&'a ConvolutionOp, F> for RNDM where F: Float + ToNalgebraRealField, - BT: BTImpl, - K: SimpleConvolutionKernel, - Weighted, F>: LocalAnalysis, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, { - type PreCodomain = PreBTFN, N>; + fn norm(&self, op𝒟: &'a ConvolutionOp) -> F { + self.apply(op𝒟.apply(self)).sqrt() + } +} + +impl DiscreteMeasureOp, F> for ConvolutionOp +where + F: Float + ToNalgebraRealField, + BT: BTImpl, + K: SimpleConvolutionKernel, + Weighted, F>: LocalAnalysis, +{ + type PreCodomain = PreBTFN, N>; fn findim_matrix<'a, I>(&self, points: I) -> DMatrix where - I: ExactSizeIterator> + Clone, + I: ExactSizeIterator> + Clone, { // TODO: Preliminary implementation. It be best to use sparse matrices or // possibly explicit operators without matrices @@ -268,7 +291,7 @@ /// A version of [`Mapping::apply`] that does not instantiate the [`BTFN`] codomain with /// a bisection tree, instead returning a [`PreBTFN`]. This can improve performance when /// the output is to be added as the right-hand-side operand to a proper BTFN. - fn preapply(&self, μ: RNDM) -> Self::PreCodomain { + fn preapply(&self, μ: RNDM) -> Self::PreCodomain { BTFN::new_pre(self.support_generator(μ)) } } @@ -277,27 +300,27 @@ /// for [`ConvolutionSupportGenerator`]. macro_rules! 
make_convolutionsupportgenerator_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl, const N: usize> std::ops::$trait_assign - for ConvolutionSupportGenerator + impl, const N: usize> std::ops::$trait_assign + for ConvolutionSupportGenerator { fn $fn_assign(&mut self, t: F) { self.centres.$fn_assign(t); } } - impl, const N: usize> std::ops::$trait - for ConvolutionSupportGenerator + impl, const N: usize> std::ops::$trait + for ConvolutionSupportGenerator { - type Output = ConvolutionSupportGenerator; + type Output = ConvolutionSupportGenerator; fn $fn(mut self, t: F) -> Self::Output { std::ops::$trait_assign::$fn_assign(&mut self.centres, t); self } } - impl<'a, F: Float, K: SimpleConvolutionKernel, const N: usize> std::ops::$trait - for &'a ConvolutionSupportGenerator + impl<'a, F: Float, K: SimpleConvolutionKernel, const N: usize> std::ops::$trait + for &'a ConvolutionSupportGenerator { - type Output = ConvolutionSupportGenerator; + type Output = ConvolutionSupportGenerator; fn $fn(self, t: F) -> Self::Output { ConvolutionSupportGenerator { kernel: self.kernel.clone(), @@ -314,20 +337,20 @@ /// Generates an unary operation (e.g. [`std::ops::Neg`]) for [`ConvolutionSupportGenerator`]. macro_rules! 
make_convolutionsupportgenerator_unaryop { ($trait:ident, $fn:ident) => { - impl, const N: usize> std::ops::$trait - for ConvolutionSupportGenerator + impl, const N: usize> std::ops::$trait + for ConvolutionSupportGenerator { - type Output = ConvolutionSupportGenerator; + type Output = ConvolutionSupportGenerator; fn $fn(mut self) -> Self::Output { self.centres = self.centres.$fn(); self } } - impl<'a, F: Float, K: SimpleConvolutionKernel, const N: usize> std::ops::$trait - for &'a ConvolutionSupportGenerator + impl<'a, F: Float, K: SimpleConvolutionKernel, const N: usize> std::ops::$trait + for &'a ConvolutionSupportGenerator { - type Output = ConvolutionSupportGenerator; + type Output = ConvolutionSupportGenerator; fn $fn(self) -> Self::Output { ConvolutionSupportGenerator { kernel: self.kernel.clone(), diff -r 9738b51d90d7 -r 4f468d35fa29 src/sliding_fb.rs --- a/src/sliding_fb.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/sliding_fb.rs Thu Feb 26 11:38:43 2026 -0500 @@ -10,22 +10,21 @@ use itertools::izip; use std::iter::Iterator; +use crate::fb::*; +use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::plot::Plotter; +use crate::prox_penalty::{ProxPenalty, StepLengthBound}; +use crate::regularisation::SlidingRegTerm; +use crate::types::*; +use alg_tools::error::DynResult; use alg_tools::euclidean::Euclidean; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; +use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::Norm; - -use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel}; -use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::types::*; -//use crate::tolerance::Tolerance; -use 
crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared}; -use crate::fb::*; -use crate::plot::{PlotLookup, Plotting, SeqPlotter}; -use crate::regularisation::SlidingRegTerm; -//use crate::transport::TransportLipschitz; +use anyhow::ensure; /// Transport settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] @@ -42,21 +41,18 @@ #[replace_float_literals(F::cast_from(literal))] impl TransportConfig { /// Check that the parameters are ok. Panics if not. - pub fn check(&self) { - assert!(self.θ0 > 0.0); - assert!(0.0 < self.adaptation && self.adaptation < 1.0); - assert!(self.tolerance_mult_con > 0.0); + pub fn check(&self) -> DynResult<()> { + ensure!(self.θ0 > 0.0); + ensure!(0.0 < self.adaptation && self.adaptation < 1.0); + ensure!(self.tolerance_mult_con > 0.0); + Ok(()) } } #[replace_float_literals(F::cast_from(literal))] impl Default for TransportConfig { fn default() -> Self { - TransportConfig { - θ0: 0.9, - adaptation: 0.9, - tolerance_mult_con: 100.0, - } + TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0 } } } @@ -66,10 +62,14 @@ pub struct SlidingFBConfig { /// Step length scaling pub τ0: F, + // Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`] + pub σp0: F, /// Transport parameters pub transport: TransportConfig, /// Generic parameters - pub insertion: FBGenericConfig, + pub insertion: InsertionConfig, + /// Guess for curvature bound calculations. + pub guess: BoundedCurvatureGuess, } #[replace_float_literals(F::cast_from(literal))] @@ -77,8 +77,10 @@ fn default() -> Self { SlidingFBConfig { τ0: 0.99, + σp0: 0.99, transport: Default::default(), insertion: Default::default(), + guess: BoundedCurvatureGuess::BetterThanZero, } } } @@ -100,16 +102,16 @@ /// with step lengh τ and transport step length `θ_or_adaptive`. 
#[replace_float_literals(F::cast_from(literal))] pub(crate) fn initial_transport( - γ1: &mut RNDM, - μ: &mut RNDM, + γ1: &mut RNDM, + μ: &mut RNDM, τ: F, θ_or_adaptive: &mut TransportStepLength, v: D, -) -> (Vec, RNDM) +) -> (Vec, RNDM) where F: Float + ToNalgebraRealField, G: Fn(F, F) -> F, - D: DifferentiableRealMapping, + D: DifferentiableRealMapping, { use TransportStepLength::*; @@ -145,22 +147,14 @@ ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } } - AdaptiveMax { - l: ℓ_F, - ref mut max_transport, - g: ref calculate_θ, - } => { + AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => { *max_transport = max_transport.max(γ1.norm(Radon)); let θτ = τ * calculate_θ(ℓ_F, *max_transport); for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } } - FullyAdaptive { - l: ref mut adaptive_ℓ_F, - ref mut max_transport, - g: ref calculate_θ, - } => { + FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => { *max_transport = max_transport.max(γ1.norm(Radon)); let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); // Do two runs through the spikes to update θ, breaking if first run did not cause @@ -209,9 +203,9 @@ /// A posteriori transport adaptation. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn aposteriori_transport( - γ1: &mut RNDM, - μ: &mut RNDM, - μ_base_minus_γ0: &mut RNDM, + γ1: &mut RNDM, + μ: &mut RNDM, + μ_base_minus_γ0: &mut RNDM, μ_base_masses: &Vec, extra: Option, ε: F, @@ -264,36 +258,33 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. 
#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_sliding_fb_reg( + f: &Dat, + reg: &Reg, prox_penalty: &P, config: &SlidingFBConfig, iterator: I, - mut plotter: SeqPlotter, -) -> RNDM + mut plotter: Plot, + μ0: Option>, +) -> DynResult> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory>, - A: ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType = F> - + BoundedCurvature, - for<'b> &'b A::Observable: std::ops::Neg + Instance, - A::PreadjointCodomain: DifferentiableRealMapping, - RNDM: SpikeMerging, - Reg: SlidingRegTerm, - P: ProxPenalty, - PlotLookup: Plotting, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F> + BoundedCurvature, + Dat::DerivativeDomain: DifferentiableRealMapping + ClosedMul, + //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Plot: Plotter>, { // Check parameters - assert!(config.τ0 > 0.0, "Invalid step length parameter"); - config.transport.check(); + ensure!(config.τ0 > 0.0, "Invalid step length parameter"); + config.transport.check()?; // Initialise iterates - let mut μ = DiscreteMeasure::new(); + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); let mut γ1 = DiscreteMeasure::new(); - let mut residual = -b; // Has to equal $Aμ-b$. 
// Set up parameters // let opAnorm = opA.opnorm_bound(Radon, L2); @@ -301,21 +292,21 @@ // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); - let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); - let transport_lip = maybe_transport_lip.unwrap(); + let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; + let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); + let transport_lip = maybe_transport_lip?; let calculate_θ = |ℓ_F, max_transport| { let ℓ_r = transport_lip * max_transport; config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) }; - let mut θ_or_adaptive = match maybe_ℓ_F0 { + let mut θ_or_adaptive = match maybe_ℓ_F { //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), - Some(ℓ_F0) => TransportStepLength::AdaptiveMax { - l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + Ok(ℓ_F) => TransportStepLength::AdaptiveMax { + l: ℓ_F, // TODO: could estimate computing the real reesidual max_transport: 0.0, g: calculate_θ, }, - None => TransportStepLength::FullyAdaptive { + Err(_) => TransportStepLength::FullyAdaptive { l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials max_transport: 0.0, g: calculate_θ, @@ -327,8 +318,8 @@ let mut ε = tolerance.initial(); // Statistics - let full_stats = |residual: &A::Observable, μ: &RNDM, ε, stats| IterInfo { - value: residual.norm2_squared_div2() + reg.apply(μ), + let full_stats = |μ: &RNDM, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), @@ -337,9 +328,9 @@ let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, ε, 
stats.clone())) { // Calculate initial transport - let v = opA.preadjoint().apply(residual); + let v = f.differential(&μ); let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); @@ -347,8 +338,11 @@ // regularisation term conforms to the assumptions made for the transport above. let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) - let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); + // TODO: this could be optimised by doing the differential like the + // old residual2. + let μ̆ = &γ1 + &μ_base_minus_γ0; + let mut τv̆ = f.differential(μ̆) * τ; // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( @@ -362,7 +356,7 @@ ®, &state, &mut stats, - ); + )?; // A posteriori transport adaptation. 
if aposteriori_transport( @@ -404,7 +398,7 @@ ε, ins, ®, - Some(|μ̃: &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), + Some(|μ̃: &RNDM| f.apply(μ̃)), ); } @@ -419,26 +413,19 @@ μ = μ_new; } - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; // Give statistics if requested state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); - full_stats( - &residual, - &μ, - ε, - std::mem::replace(&mut stats, IterInfo::new()), - ) + full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - postprocess(μ, &config.insertion, L2Squared, opA, b) + //postprocess(μ, &config.insertion, f) + postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃)) } diff -r 9738b51d90d7 -r 4f468d35fa29 src/sliding_pdps.rs --- a/src/sliding_pdps.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/sliding_pdps.rs Thu Feb 26 11:38:43 2026 -0500 @@ -3,51 +3,53 @@ primal-dual proximal splitting method. 
*/ +use crate::fb::*; +use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::plot::Plotter; +use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; +use crate::regularisation::SlidingRegTerm; +use crate::sliding_fb::{ + aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength, +}; +use crate::types::*; +use alg_tools::convex::{Conjugable, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::error::DynResult; +use alg_tools::euclidean::ClosedEuclidean; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::linops::{ + BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV, +}; +use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{Norm, L2}; +use anyhow::ensure; use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; //use colored::Colorize; //use nalgebra::{DVector, DMatrix}; use std::iter::Iterator; -use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::direct_product::Pair; -use alg_tools::euclidean::Euclidean; -use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; -use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{Dist, Norm}; -use alg_tools::norms::{PairNorm, L2}; - -use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; -use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::types::*; -// use crate::transport::TransportLipschitz; -//use crate::tolerance::Tolerance; -use crate::fb::*; -use crate::plot::{PlotLookup, Plotting, SeqPlotter}; -use 
crate::regularisation::SlidingRegTerm; -// use crate::dataterm::L2Squared; -use crate::dataterm::{calculate_residual, calculate_residual2}; -use crate::sliding_fb::{ - aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, -}; - /// Settings for [`pointsource_sliding_pdps_pair`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct SlidingPDPSConfig { - /// Primal step length scaling. + /// Overall primal step length scaling. pub τ0: F, - /// Primal step length scaling. + /// Primal step length scaling for additional variable. pub σp0: F, - /// Dual step length scaling. + /// Dual step length scaling for additional variable. + /// + /// Taken zero for [`pointsource_sliding_fb_pair`]. pub σd0: F, /// Transport parameters pub transport: TransportConfig, /// Generic parameters - pub insertion: FBGenericConfig, + pub insertion: InsertionConfig, + /// Guess for curvature bound calculations. + pub guess: BoundedCurvatureGuess, } #[replace_float_literals(F::cast_from(literal))] @@ -57,16 +59,14 @@ τ0: 0.99, σd0: 0.05, σp0: 0.99, - transport: TransportConfig { - θ0: 0.9, - ..Default::default() - }, + transport: TransportConfig { θ0: 0.9, ..Default::default() }, insertion: Default::default(), + guess: BoundedCurvatureGuess::BetterThanZero, } } } -type MeasureZ = Pair, Z>; +type MeasureZ = Pair, Z>; /// Iteratively solve the pointsource localisation with an additional variable /// using sliding primal-dual proximal splitting @@ -76,67 +76,66 @@ pub fn pointsource_sliding_pdps_pair< F, I, - A, S, + Dat, Reg, P, Z, R, Y, + Plot, /*KOpM, */ KOpZ, H, const N: usize, >( - opA: &A, - b: &A::Observable, - reg: Reg, + f: &Dat, + reg: &Reg, prox_penalty: &P, config: &SlidingPDPSConfig, iterator: I, - mut plotter: SeqPlotter, + mut plotter: Plot, + (μ0, mut z, mut y): (Option>, Z, Y), //opKμ : KOpM, opKz: &KOpZ, fnR: &R, fnH: &H, - mut z: Z, - mut y: Y, -) -> MeasureZ +) -> DynResult> where F: Float + 
ToNalgebraRealField, - I: AlgIteratorFactory>, - A: ForwardModel, F, PairNorm, PreadjointCodomain = Pair> - + AdjointProductPairBoundedBy, P, IdOp, FloatType = F> - + BoundedCurvature, - S: DifferentiableRealMapping, - for<'b> &'b A::Observable: std::ops::Neg + Instance, - PlotLookup: Plotting, - RNDM: SpikeMerging, - Reg: SlidingRegTerm, - P: ProxPenalty, - // KOpM : Linear, Codomain=Y> - // + GEMV> + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + //Pair: ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + // KOpM : Linear, Codomain=Y> + // + GEMV> // + Preadjointable< - // RNDM, Y, + // RNDM, Y, // PreadjointCodomain = S, // > // + TransportLipschitz - // + AdjointProductBoundedBy, 𝒟, FloatType=F>, + // + AdjointProductBoundedBy, 𝒟, FloatType=F>, // for<'b> KOpM::Preadjoint<'b> : GEMV, // Since Z is Hilbert, we may just as well use adjoints for K_z. 
KOpZ: BoundedLinear + GEMV - + Adjointable, - for<'b> KOpZ::Adjoint<'b>: GEMV, - Y: AXPY + Euclidean + Clone + ClosedAdd, + + SimplyAdjointable, + KOpZ::SimpleAdjoint: GEMV, + Y: ClosedEuclidean, for<'b> &'b Y: Instance, - Z: AXPY + Euclidean + Clone + Norm + Dist, + Z: ClosedEuclidean, for<'b> &'b Z: Instance, R: Prox, H: Conjugable, for<'b> H::Conjugate<'b>: Prox, + Plot: Plotter>, { // Check parameters - assert!( + /*ensure!( config.τ0 > 0.0 && config.τ0 < 1.0 && config.σp0 > 0.0 @@ -144,26 +143,25 @@ && config.σd0 > 0.0 && config.σp0 * config.σd0 <= 1.0, "Invalid step length parameters" - ); - config.transport.check(); + );*/ + config.transport.check()?; // Initialise iterates - let mut μ = DiscreteMeasure::new(); + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); let mut γ1 = DiscreteMeasure::new(); - let mut residual = calculate_residual(Pair(&μ, &z), opA, b); - let zero_z = z.similar_origin(); + //let zero_z = z.similar_origin(); // Set up parameters // TODO: maybe this PairNorm doesn't make sense here? // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); - let nKz = opKz.opnorm_bound(L2, L2); + let nKz = opKz.opnorm_bound(L2, L2)?; let ℓ = 0.0; - let opIdZ = IdOp::new(); - let (l, l_z) = opA - .adjoint_product_pair_bound(prox_penalty, &opIdZ) - .unwrap(); + let idOpZ = IdOp::new(); + let opKz_adj = opKz.adjoint(); + let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; + // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -172,7 +170,8 @@ // // To do so, we first solve σ_p and σ_d from standard PDPS step length condition // ^^^^^ < 1. then we solve τ from the rest. - let σ_d = config.σd0 / nKz; + // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. 
+ let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz }; let σ_p = config.σp0 / (l_z + config.σd0 * nKz); // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) @@ -182,29 +181,29 @@ let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); let ψ = 1.0 - τ * l; let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; - assert!(β < 1.0); + ensure!(β < 1.0); // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); // The factor two in the manuscript disappears due to the definition of 𝚹 being // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. - let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); - let transport_lip = maybe_transport_lip.unwrap(); + let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); + let transport_lip = maybe_transport_lip?; let calculate_θ = |ℓ_F, max_transport| { let ℓ_r = transport_lip * max_transport; config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) }; - let mut θ_or_adaptive = match maybe_ℓ_F0 { + let mut θ_or_adaptive = match maybe_ℓ_F { // We assume that the residual is decreasing. 
- Some(ℓ_F0) => TransportStepLength::AdaptiveMax { - l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + Ok(ℓ_F) => TransportStepLength::AdaptiveMax { + l: ℓ_F, // TODO: could estimate computing the real reesidual max_transport: 0.0, g: calculate_θ, }, - None => TransportStepLength::FullyAdaptive { - l: F::EPSILON, - max_transport: 0.0, - g: calculate_θ, - }, + Err(_) => { + TransportStepLength::FullyAdaptive { + l: F::EPSILON, max_transport: 0.0, g: calculate_θ + } + } }; // Acceleration is not currently supported // let γ = dataterm.factor_of_strong_convexity(); @@ -218,8 +217,8 @@ let starH = fnH.conjugate(); // Statistics - let full_stats = |residual: &A::Observable, μ: &RNDM, z: &Z, ε, stats| IterInfo { - value: residual.norm2_squared_div2() + let full_stats = |μ: &RNDM, z: &Z, ε, stats| IterInfo { + value: f.apply(Pair(μ, z)) + fnR.apply(z) + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), @@ -231,9 +230,9 @@ let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(v, _) = opA.preadjoint().apply(&residual); + let Pair(v, _) = f.differential(Pair(&μ, &z)); //opKμ.preadjoint().apply_add(&mut v, y); // We want to proceed as in Example 4.12 but with v and v̆ as in §5. // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have @@ -242,6 +241,8 @@ // This is much easier with K_μ = 0, which is the only reason why are enforcing it. // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. + //dbg!(&μ); + let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); @@ -249,9 +250,11 @@ // regularisation term conforms to the assumptions made for the transport above. 
let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) - let residual_μ̆ = - calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); + // let residual_μ̆ = + // calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); + // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); + // TODO: might be able to optimise the measure sum working as calculate_residual2 above. + let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ; // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. @@ -266,11 +269,11 @@ ®, &state, &mut stats, - ); + )?; // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}} let mut z_new = τz̆; - opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ); + opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); z_new = fnR.prox(σ_p, z_new + &z); // A posteriori transport adaptation. 
@@ -279,7 +282,7 @@ &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, - Some(z_new.dist(&z, L2)), + Some(z_new.dist2(&z)), ε, &config.transport, ) { @@ -313,7 +316,7 @@ ε, ins, ®, - //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), + //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), ); } @@ -336,9 +339,6 @@ y = starH.prox(σ_d, y); z = z_new; - // Update residual - residual = calculate_residual(Pair(&μ, &z), opA, b); - // Update step length parameters // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); @@ -348,26 +348,78 @@ state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); - full_stats( - &residual, - &μ, - &z, - ε, - std::mem::replace(&mut stats, IterInfo::new()), - ) + full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - let fit = |μ̃: &RNDM| { - (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() - //+ fnR.apply(z) + reg.apply(μ) + let fit = |μ̃: &RNDM| { + f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) }; μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); μ.prune(); - Pair(μ, z) + Ok(Pair(μ, z)) } + +/// Iteratively solve the pointsource localisation with an additional variable +/// using sliding forward-backward splitting. +/// +/// The implementation uses [`pointsource_sliding_pdps_pair`] with appropriate dummy +/// variables, operators, and functions. 
+#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_sliding_fb_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + config: &SlidingFBConfig, + iterator: I, + plotter: Plot, + (μ0, z): (Option>, Z), + //opKμ : KOpM, + fnR: &R, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + Z: ClosedEuclidean + AXPY + Clone, + for<'b> &'b Z: Instance, + R: Prox, + Plot: Plotter>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance>, + // Loc<0, F>: StaticEuclidean> + // + Instance> + // + VectorSpace, +{ + let opKz: ZeroOp, _, _, F> = + ZeroOp::new_dualisable(StaticEuclideanOriginGenerator, z.dual_origin()); + let fnH = Zero::new(); + // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 + // needs to be chosen in a general case; for the problem of this function, anything is valid. + let &SlidingFBConfig { τ0, σp0, insertion, transport, guess } = config; + let pdps_config = SlidingPDPSConfig { τ0, σp0, insertion, transport, guess, σd0: 0.0 }; + + pointsource_sliding_pdps_pair( + f, + reg, + prox_penalty, + &pdps_config, + iterator, + plotter, + (μ0, z, Loc([])), + &opKz, + fnR, + &fnH, + ) +} diff -r 9738b51d90d7 -r 4f468d35fa29 src/tolerance.rs --- a/src/tolerance.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/tolerance.rs Thu Feb 26 11:38:43 2026 -0500 @@ -1,33 +1,33 @@ //!
Tolerance update schemes for subproblem solution quality -use serde::{Serialize, Deserialize}; +use crate::types::*; use numeric_literals::replace_float_literals; -use crate::types::*; +use serde::{Deserialize, Serialize}; /// Update style for optimality system solution tolerance #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[allow(dead_code)] -pub enum Tolerance { +pub enum Tolerance { /// $ε_k = εθ^k$ for the `factor` $θ$ and initial tolerance $ε$. - Exponential{ factor : F, initial : F }, + Exponential { factor: F, initial: F }, /// $ε_k = ε/(1+θk)^p$ for the `factor` $θ$, `exponent` $p$, and initial tolerance $ε$. - Power{ factor : F, exponent : F, initial : F}, + Power { factor: F, exponent: F, initial: F }, /// $ε_k = εθ^{⌊k^p⌋}$ for the `factor` $θ$, initial tolerance $ε$, and exponent $p$. - SlowExp{ factor : F, exponent : F, initial : F } + SlowExp { factor: F, exponent: F, initial: F }, } #[replace_float_literals(F::cast_from(literal))] -impl Default for Tolerance { +impl Default for Tolerance { fn default() -> Self { Tolerance::Power { - initial : 0.5, - factor : 0.2, - exponent : 1.4 // 1.5 works but is already slower in practise on our examples. + initial: 0.5, + factor: 0.2, + exponent: 1.4, // 1.5 works but is already slower in practise on our examples. } } } #[replace_float_literals(F::cast_from(literal))] -impl Tolerance { +impl Tolerance { /// Get the initial tolerance pub fn initial(&self) -> F { match self { @@ -47,21 +47,19 @@ } /// Set the initial tolerance - pub fn set_initial(&mut self, set : F) { + pub fn set_initial(&mut self, set: F) { *self.initial_mut() = set; } /// Update `tolerance` for iteration `iter`. /// `tolerance` may or may not be used depending on the specific /// update scheme. - pub fn update(&self, tolerance : F, iter : usize) -> F { + pub fn update(&self, tolerance: F, iter: usize) -> F { match self { - &Tolerance::Exponential { factor, .. 
} => { - tolerance * factor - }, + &Tolerance::Exponential { factor, .. } => tolerance * factor, &Tolerance::Power { factor, exponent, initial } => { - initial /(1.0 + factor * F::cast_from(iter)).powf(exponent) - }, + initial / (1.0 + factor * F::cast_from(iter)).powf(exponent) + } &Tolerance::SlowExp { factor, exponent, initial } => { // let m = (speed // * factor.powi(-(iter as i32)) @@ -69,20 +67,20 @@ // ).floor().as_(); let m = F::cast_from(iter).powf(exponent).floor().as_(); initial * factor.powi(m) - }, + } } } } impl std::ops::MulAssign for Tolerance { - fn mul_assign(&mut self, factor : F) { + fn mul_assign(&mut self, factor: F) { *self.initial_mut() *= factor; } } impl std::ops::Mul for Tolerance { type Output = Tolerance; - fn mul(mut self, factor : F) -> Self::Output { + fn mul(mut self, factor: F) -> Self::Output { *self.initial_mut() *= factor; self } diff -r 9738b51d90d7 -r 4f468d35fa29 src/types.rs --- a/src/types.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/types.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,159 +2,146 @@ use numeric_literals::replace_float_literals; +use alg_tools::iterate::LogRepr; use colored::ColoredString; -use serde::{Serialize, Deserialize}; -use alg_tools::iterate::LogRepr; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::{Norm, L1}; +use serde::{Deserialize, Serialize}; -pub use alg_tools::types::*; +pub use alg_tools::error::DynResult; pub use alg_tools::loc::Loc; pub use alg_tools::sets::Cube; +pub use alg_tools::types::*; // use crate::measures::DiscreteMeasure; /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. 
-pub trait ClapFloat : Float - + std::str::FromStr - + std::fmt::Display {} +pub trait ClapFloat: + Float + std::str::FromStr + std::fmt::Display +{ +} impl ClapFloat for f32 {} impl ClapFloat for f64 {} /// Structure for storing iteration statistics #[derive(Debug, Clone, Serialize)] -pub struct IterInfo { +pub struct IterInfo { /// Function value - pub value : F, + pub value: F, /// Number of spikes - pub n_spikes : usize, + pub n_spikes: usize, /// Number of iterations this statistic covers - pub this_iters : usize, + pub this_iters: usize, /// Number of spikes inserted since last IterInfo statistic - pub inserted : usize, + pub inserted: usize, /// Number of spikes removed by merging since last IterInfo statistic - pub merged : usize, + pub merged: usize, /// Number of spikes removed by pruning since last IterInfo statistic - pub pruned : usize, + pub pruned: usize, /// Number of inner iterations since last IterInfo statistic - pub inner_iters : usize, + pub inner_iters: usize, /// Tuple of (transported mass, source mass) - pub untransported_fraction : Option<(F, F)>, + pub untransported_fraction: Option<(F, F)>, /// Tuple of (|destination mass - untransported_mass|, transported mass) - pub transport_error : Option<(F, F)>, + pub transport_error: Option<(F, F)>, /// Current tolerance - pub ε : F, + pub ε: F, // /// Solve fin.dim problem for this measure to get the optimal `value`. - // pub postprocessing : Option>, + // pub postprocessing : Option>, } -impl IterInfo { +impl IterInfo { /// Initialise statistics with zeros. `ε` and `value` are unspecified. 
pub fn new() -> Self { IterInfo { - value : F::NAN, - n_spikes : 0, - this_iters : 0, - merged : 0, - inserted : 0, - pruned : 0, - inner_iters : 0, - ε : F::NAN, + value: F::NAN, + n_spikes: 0, + this_iters: 0, + merged: 0, + inserted: 0, + pruned: 0, + inner_iters: 0, + ε: F::NAN, // postprocessing : None, - untransported_fraction : None, - transport_error : None, + untransported_fraction: None, + transport_error: None, } } } #[replace_float_literals(F::cast_from(literal))] -impl LogRepr for IterInfo where F : LogRepr + Float { +impl LogRepr for IterInfo +where + F: LogRepr + Float, +{ fn logrepr(&self) -> ColoredString { - format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", - self.value.logrepr(), - self.n_spikes, - self.ε, - self.inner_iters as float / self.this_iters.max(1) as float, - self.inserted as float / self.this_iters.max(1) as float, - self.merged as float / self.this_iters.max(1) as float, - self.pruned as float / self.this_iters.max(1) as float, - match self.untransported_fraction { - None => format!(""), - Some((a, b)) => if b > 0.0 { - format!(", untransported {:.2}%", 100.0*a/b) + format!( + "{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", + self.value.logrepr(), + self.n_spikes, + self.ε, + self.inner_iters as float / self.this_iters.max(1) as float, + self.inserted as float / self.this_iters.max(1) as float, + self.merged as float / self.this_iters.max(1) as float, + self.pruned as float / self.this_iters.max(1) as float, + match self.untransported_fraction { + None => format!(""), + Some((a, b)) => + if b > 0.0 { + format!(", untransported {:.2}%", 100.0 * a / b) } else { format!("") - } - }, - match self.transport_error { - None => format!(""), - Some((a, b)) => if b > 0.0 { - format!(", transport error {:.2}%", 100.0*a/b) + }, + }, + match self.transport_error { + None => format!(""), + Some((a, b)) => + if b > 0.0 { + format!(", transport error {:.2}%", 100.0 * a / b) } else { format!("") - 
} - } - ).as_str().into() + }, + } + ) + .as_str() + .into() } } /// Branch and bound refinement settings #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct RefinementSettings { +pub struct RefinementSettings { /// Function value tolerance multiplier for bisection tree refinement in /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. - pub tolerance_mult : F, + pub tolerance_mult: F, /// Maximum branch and bound steps - pub max_steps : usize, + pub max_steps: usize, } #[replace_float_literals(F::cast_from(literal))] -impl Default for RefinementSettings { +impl Default for RefinementSettings { fn default() -> Self { RefinementSettings { - tolerance_mult : 0.1, - max_steps : 50000, + tolerance_mult: 0.1, + max_steps: 50000, } } } /// Data term type #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] -pub enum DataTerm { +pub enum DataTermType { /// $\\|z\\|\_2^2/2$ - L2Squared, + L222, /// $\\|z\\|\_1$ L1, } -impl DataTerm { - /// Calculate the data term value at residual $z=Aμ - b$. - pub fn value_at_residual + Norm>(&self, z : E) -> F { - match self { - Self::L2Squared => z.norm2_squared_div2(), - Self::L1 => z.norm(L1), - } - } -} - -/// Type for indicating norm-2-squared data fidelity or transport cost. -#[derive(Clone, Copy, Serialize, Deserialize)] -pub struct L2Squared; - -/// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. -pub trait Lipschitz { - /// The type of floats - type FloatType : Float; - - /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. - fn lipschitz_factor(&self, seminorm : M) -> Option; -} +pub use alg_tools::mapping::Lipschitz; /// Trait for norm-bounded functions. pub trait NormBounded { - type FloatType : Float; + type FloatType: Float; /// Returns a bound on the values of this function object in the `M`-norm. 
- fn norm_bound(&self, m : M) -> Self::FloatType; + fn norm_bound(&self, m: M) -> Self::FloatType; }