Denoising routine for testing + visualisation tools

Thu, 21 Nov 2019 18:44:27 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 21 Nov 2019 18:44:27 -0500
changeset 9
1cffd3d07fe2
parent 8
c5aabb2c41d9
child 10
fdf2f44be973

Denoising routine for testing + visualisation tools

Project.toml file | annotate | diff | comparison | revisions
src/Denoise.jl file | annotate | diff | comparison | revisions
src/ImageTools.jl file | annotate | diff | comparison | revisions
src/Visualise.jl file | annotate | diff | comparison | revisions
test/denoise.jl file | annotate | diff | comparison | revisions
--- a/Project.toml	Tue Nov 19 10:15:38 2019 -0500
+++ b/Project.toml	Thu Nov 21 18:44:27 2019 -0500
@@ -2,3 +2,13 @@
 uuid = "b548cc0d-4ade-417e-bf62-0e39f9d2eee9"
 authors = ["Tuomo Valkonen <tuomov@iki.fi>"]
 version = "0.1.0"
+
+[deps]
+AlgTools = "c46e2e78-5339-41fd-a966-983ff60ab8e7"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
+Measures = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Denoise.jl	Thu Nov 21 18:44:27 2019 -0500
@@ -0,0 +1,116 @@
+########################################################
+# Basic TV denoising via primal–dual proximal splitting
+########################################################
+
+__precompile__()
+
+module Denoise
+
+using AlgTools.Util
+import AlgTools.Iterate
+using ImageTools.Gradient
+
+##############
+# Our exports
+##############
+
+export denoise_pdps
+
+#############
+# Data types
+#############
+
+ImageSize = Tuple{Integer,Integer}
+Image = Array{Float64,2}
+Primal = Image
+Dual = Array{Float64,3}
+
+#########################
+# Iterate initialisation
+#########################
+
+function init_rest(x::Primal)
+    imdim=size(x)
+
+    y = zeros(2, imdim...)
+    Δx = copy(x)
+    Δy = copy(y)
+    x̄ = copy(x)
+
+    return x, y, Δx, Δy, x̄
+end
+
+function init_iterates(xinit::Image, b)
+    return init_rest(copy(xinit))
+end
+
+function init_iterates(xinit::Nothing, b :: Image)
+    return init_rest(zeros(size(b)...))
+end
+
+############
+# Algorithm
+############
+
+function denoise_pdps(b :: Image;
+                      xinit :: Union{Image,Nothing} = nothing,
+                      iterate = AlgTools.simple_iterate,
+                      params::NamedTuple) where DisplacementT
+
+    ################################                                        
+    # Extract and set up parameters
+    ################################                    
+
+    α, ρ = params.α, params.ρ
+    τ₀, σ₀ =  params.τ₀, params.σ₀
+
+    R_K = ∇₂_norm₂₂_est
+    γ = 1
+
+    @assert(τ₀*σ₀ < 1)
+    σ = σ₀/R_K
+    τ = τ₀/R_K
+    
+    ######################
+    # Initialise iterates
+    ######################
+
+    x, y, Δx, Δy, x̄ = init_iterates(xinit, b)
+
+    ####################
+    # Run the algorithm
+    ####################
+
+    v = iterate(params) do verbose :: Function
+        ω = params.accel ? 1/√(1+2*γ*τ) : 1
+        
+        ∇₂ᵀ!(Δx, y)                    # primal step:
+        @. x̄ = x                       # |  save old x for over-relax
+        @. x = (x-τ*(Δx-b))/(1+τ)      # |  prox
+        @. x̄ = (1+ω)*x - ω*x̄           # over-relax: x̄ = 2x-x_old
+        ∇₂!(Δy, x̄)                     # dual step: y
+        @. y = (y + σ*Δy)/(1 + σ*ρ/α)  # |
+        proj_norm₂₁ball!(y, α)         # |  prox
+
+        if params.accel
+            τ, σ = τ*ω, σ/ω
+        end
+                
+        ################################
+        # Give function value if needed
+        ################################
+        v = verbose() do            
+            ∇₂!(Δy, x)
+            value = norm₂²(b-x)/2 + params.α*γnorm₂₁(Δy, params.ρ)
+            value, x
+        end
+
+        v
+    end
+
+    return x, y, v
+end
+
+end # Module
+
+
--- a/src/ImageTools.jl	Tue Nov 19 10:15:38 2019 -0500
+++ b/src/ImageTools.jl	Thu Nov 21 18:44:27 2019 -0500
@@ -6,6 +6,8 @@
 
 include("Gradient.jl")
 include("Translate.jl")
+include("Denoise.jl")
+include("Visualise.jl")
 
 end # module
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Visualise.jl	Thu Nov 21 18:44:27 2019 -0500
@@ -0,0 +1,192 @@
+##################################################
+# Visualising and data-collecting iteration tools
+##################################################
+
+module Visualise
+
+using Printf
+using Distributed
+using FileIO
+using Setfield
+using Images, Plots, Measures
+
+using AlgTools.Util
+using AlgTools.StructTools
+using AlgTools.LinkedLists
+
+##############
+# Our exports
+##############
+
+export LogEntry,
+       bg_visualise,
+       visualise,
+       clip,
+       grayimg,
+       secs_ns,
+       iterate_visualise,
+       initialise_visualisation,
+       finalise_visualisation
+
+##################
+# Data structures
+##################
+
+struct LogEntry <: IterableStruct
+    iter :: Int
+    time :: Float64
+    function_value :: Float64
+end
+
+struct State
+    vis :: Union{Distributed.RemoteChannel,Bool,Nothing}
+    visproc :: Union{Nothing,Future}
+    start_time :: Union{Real,Nothing}
+    wasted_time :: Real
+    log :: LinkedList{LogEntry}
+end
+
+##################
+# Helper routines
+##################
+
+@inline function secs_ns()
+    return convert(Float64, time_ns())*1e-9
+end
+
+clip = x -> min(max(x, 0.0), 1.0)
+grayimg = im -> Gray.(clip.(im))
+
+################
+# Visualisation
+################
+
+function bg_visualise(rc)
+    while true
+        imgs=take!(rc)
+        # Take only the latest image to visualise
+        while isready(rc)
+            imgs=take!(rc)
+        end
+        # We're done if we were fed an empty image
+        if isnothing(imgs)
+            break
+        end
+        do_visualise(imgs)
+    end
+    return
+end
+
+function do_visualise(imgs)
+    plt = im -> plot(grayimg(im), showaxis=false, grid=false, aspect_ratio=:equal, margin=2mm)
+    display(plot([plt(imgs[i]) for i =1:length(imgs)]..., reuse=true, margin=0mm))
+end
+
+function visualise(rc_or_vis, imgs)
+    if isa(rc_or_vis, RemoteChannel)
+        rc = rc_or_vis
+        while isready(rc)
+            take!(rc)
+        end
+        put!(rc, imgs)
+    elseif isa(rc_or_vis, Bool) && rc_or_vis
+        do_visualise(imgs)
+    end
+end
+
+######################################################
+# Iterator that does visualisation and log collection
+######################################################
+
+function iterate_visualise(st :: State,
+                           step :: Function,
+                           params :: NamedTuple) where DisplacementT
+    try
+        for iter=1:params.maxiter 
+            st = step() do calc_objective
+                if isnothing(st.start_time)
+                    # The Julia precompiler is a miserable joke, apparently not crossing module
+                    # boundaries, so only start timing after the first iteration.
+                    st = @set st.start_time=secs_ns()
+                end
+
+                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0
+                    
+                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
+                    verb_start = secs_ns()
+                    tm = verb_start - st.start_time - st.wasted_time
+                    value, x = calc_objective()
+
+                    entry = LogEntry(iter, tm, value)
+
+                    # (**) Collect a singly-linked list of log to avoid array resizing
+                    # while iterating
+                    st = @set st.log=LinkedListEntry(entry, st.log)
+                    
+                    if verb
+                        @printf("%d/%d J=%f\n", iter, params.maxiter, value)
+                        visualise(st.vis, (x,))
+                    end
+
+                    if params.save_iterations
+                        fn = t -> "$(params.save_prefix)_$(t)_iter$(iter).png"
+                        save(File(format"PNG", fn("reco")), grayimg(x))
+                    end
+
+                    st = @set st.wasted_time += (secs_ns() - verb_start)
+                end
+                
+                return st
+            end
+        end
+    catch ex
+        if isa(ex, InterruptException)
+            # If SIGINT is received (user pressed ^C), terminate computations,
+            # returning current status. Effectively, we do not call `step()` again,
+            # ending the iterations, but letting the algorithm finish up.
+            # Assuming (**) above occurs atomically, `st.log` should be valid, but
+            # any results returned by the algorithm itself may be partial, as for
+            # reasons of efficiency we do *not* store results of an iteration until
+            # the next iteration is finished.
+            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
+        else
+            throw(ex)
+        end
+    end
+    
+    return st
+end
+
+####################
+# Launcher routines
+####################
+
+function initialise_visualisation(visualise; iterator=iterate_visualise)
+    # Create visualisation
+    if visualise
+        rc = RemoteChannel()
+        visproc = @spawn bg_visualise(rc)
+        vis =rc
+        #vis = true
+
+        sleep(0.01)
+    else
+        vis = false
+        visproc = nothing
+    end
+
+    st = State(vis, visproc, nothing, 0.0, nothing)
+    iterate = curry(iterate_visualise, st)
+
+    return st, iterate
+end
+
+function finalise_visualisation(st)
+    if isa(st.rc, RemoteChannel)
+        # Tell subprocess to finish, and wait
+        put!(st.rc, nothing)
+        wait(st.visproc)
+    end
+end
+
+end # Module
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/denoise.jl	Thu Nov 21 18:44:27 2019 -0500
@@ -0,0 +1,64 @@
+##################
+# Denoise testing
+##################
+
+__precompile__()
+
+using Printf
+using Images
+import TestImages
+
+using AlgTools.Util
+using AlgTools.LinkedLists
+using ImageTools.Denoise
+using ImageTools.Visualise
+
+const default_save_prefix="denoise_result_"
+
+const default_params = (
+    α = 1,
+    τ₀ = 5,
+    σ₀ = 0.99/5,
+    ρ = 0,
+    accel = true,
+    noise_level = 0.5,
+    verbose_iter = 10,
+    maxiter = 1000,
+    save_results = false,
+    save_iterations = false,
+    image_name = "lighthouse",
+)
+
+#######################
+# Main testing routine
+#######################
+
+function test_denoise(;
+                      visualise=true,
+                      save_prefix=default_save_prefix,
+                      kwargs...)
+
+    
+    # Parameters for this experiment
+    params = default_params ⬿ kwargs
+    params = params ⬿ (save_prefix = save_prefix * params.image_name,)
+
+    # Load image and add noise
+    b = Float64.(Gray.(TestImages.testimage(params.image_name)))
+    b_noisy = b .+ params.noise_level.*randn(size(b)...)
+
+    # Launch (background) visualiser
+    st, iterate = initialise_visualisation(visualise)
+
+    # Run algorithm
+    x, y, st = denoise_pdps(b_noisy; iterate=iterate, params=params)
+
+    if params.save_results
+        perffile = params.save_prefix * ".txt"
+        println("Saving " * perffile)
+        write_log(perffile, st.log, "# params = $(params)\n")
+    end
+
+    # Exit background visualiser
+    finish_visualisation(st)
+end

mercurial