src/AlgorithmPET.jl

changeset 8
e4ad8f7ce671
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/AlgorithmPET.jl	Fri Apr 19 17:00:37 2024 +0300
@@ -0,0 +1,242 @@
+####################################################################
+# Predictive online PDPS for optical flow with known velocity field
+####################################################################
+
+__precompile__()
+
+module AlgorithmPET
+
+identifier = "pet_known_orig"
+
+using Printf
+
+using AlgTools.Util
+import AlgTools.Iterate
+using ImageTools.Gradient
+using ImageTools.Translate
+  
+using ..Radon   
+using ImageTransformations
+using Images, CoordinateTransformations, Rotations, OffsetArrays
+using ImageCore, Interpolations
+
+using ..OpticalFlow: ImageSize,
+                     Image,
+                     petpdflow!
+
+#########################
+# Iterate initialisation
+#########################
+
+
+
+function init_rest(x::Image)
+    imdim=size(x)
+
+    y = zeros(2, imdim...)
+    Δx = copy(x)
+    Δy = copy(y)
+    x̄ = copy(x)
+    radonx = copy(x)
+
+    return x, y, Δx, Δy, x̄, radonx
+end
+
+function init_iterates(xinit::Image)
+    return init_rest(copy(xinit))
+end
+
+function init_iterates(dim::ImageSize)
+    return init_rest(zeros(dim...))
+end
+
+#########################
+# PETscan related
+#########################
+function petvalue(x, b, c)
+    tmp = similar(b)
+    radon!(tmp, x)
+    return sum(@. tmp - b*log(tmp+c))
+end
+
+function petgrad!(res, x, b, c, S)
+    tmp = similar(b)
+    radon!(tmp, x)
+    @. tmp = S .- b/(tmp+c)
+    backproject!(res, S.*tmp)
+end
+
+function proj_nonneg!(y)
+    @inbounds @simd for i=1:length(y)
+        if y[i] < 0
+            y[i] = 0
+        end
+    end
+    return y
+end
+
+
+
+############
+# Algorithm
+############
+
+function step_lengths(params, γ, R_K²)
+    ρ̃₀, τ₀, σ₀, σ̃₀ =  params.ρ̃₀, params.τ₀, params.σ₀, params.σ̃₀
+    δ = params.δ
+    ρ = isdefined(params, :phantom_ρ) ? params.phantom_ρ : params.ρ
+    Λ = params.Λ
+    Θ = params.dual_flow ? Λ : 1
+
+    τ = τ₀/γ
+    @assert(1+γ*τ ≥ Λ)
+    σ = σ₀*1/(τ*R_K²)
+    #σ = σ₀*min(1/(τ*R_K²), 1/max(0, τ*R_K²/((1+γ*τ-Λ)*(1-δ))-ρ))
+    q = δ*(1+σ*ρ)/Θ
+    if 1 ≥ q
+        σ̃ = σ̃₀*σ/q
+        #ρ̃ = ρ̃₀*max(0, ((Θ*σ)/(2*δ*σ̃^2*(1+σ*ρ))+1/(2σ)-1/σ̃))
+        ρ̃ = max(0, (1-q)/(2*σ))
+    else
+        σ̃ = σ̃₀*σ/(q*(1-√(1-1/q)))
+        ρ̃ = 0
+    end
+    
+    #println("Step length parameters: τ=$(τ), σ=$(σ), σ̃=$(σ̃), ρ̃=$(ρ̃)")
+
+    return τ, σ, σ̃, ρ̃
+end
+
+function solve( :: Type{DisplacementT};
+               dim :: ImageSize,
+               iterate = AlgTools.simple_iterate,
+               params::NamedTuple) where DisplacementT
+
+    ################################                                        
+    # Extract and set up parameters
+    ################################                    
+
+    α, ρ = params.α, params.ρ
+    R_K² = ∇₂_norm₂₂_est²
+    γ = 1
+    # τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)
+    λ = params.λ
+    ω = 1
+    c = params.c*ones(params.radondims...)
+
+    ρ̃₀, τ₀, σ₀, σ̃₀ =  params.ρ̃₀, params.τ₀, params.σ₀, params.σ̃₀
+
+    # Update step length parameters                    
+    L = 300.0
+    τ = τ₀/L
+    σ = σ₀*(1-τ₀)/(R_K²*τ)
+    println("Step length parameters: L=$(round(L, digits=4)), τ=$(round(τ, digits=4)), σ=$(round(σ, digits=4))")
+
+
+
+    δ = params.δ
+    ρ = isdefined(params, :phantom_ρ) ? params.phantom_ρ : params.ρ
+    Λ = params.Λ
+    Θ = params.dual_flow ? Λ : 1
+
+    q = δ*(1+σ*ρ)/Θ
+    if 1 ≥ q
+        σ̃ = σ̃₀*σ/q
+        #ρ̃ = ρ̃₀*max(0, ((Θ*σ)/(2*δ*σ̃^2*(1+σ*ρ))+1/(2σ)-1/σ̃))
+        ρ̃ = max(0, (1-q)/(2*σ))
+    else
+        σ̃ = σ̃₀*σ/(q*(1-√(1-1/q)))
+        ρ̃ = 0
+    end
+
+    ######################
+    # Initialise iterates
+    ######################
+
+    x, y, Δx, Δy, x̄, r∇ = init_iterates(dim)
+    
+    # L = 1.0
+    # oldpetgradx = zeros(size(x)...)
+    # petgradx = zeros(size(x))
+    # oldx = ones(size(x))
+
+    ####################
+    # Run the algorithm
+    ####################
+                        # THIS IS THE step function inside iterate_visualise
+    v = iterate(params) do verbose :: Function,
+                           b :: Image,                   # noisy_sinogram
+                           v_known :: DisplacementT,
+                           theta_known :: DisplacementT,
+                           b_true :: Image,
+                           S :: Image    
+        
+        ##################################
+        # Update the step length parameter
+        ##################################
+        # τ = τ₀/L
+        # σ = σ₀*(1-τ₀)/(R_K²*τ)
+        # println("Step length parameters: L=$(round(L, digits=4)), τ=$(round(τ, digits=4)), σ=$(round(σ, digits=4))") 
+
+
+        ###################    
+        # Prediction steps
+        ###################
+
+        petpdflow!(x, Δx, y, Δy, v_known, theta_known, params.dual_flow)                      # Old algorithm
+        #pdflow!(x, Δx, y, Δy, v_known, theta_known, params.dual_flow, 1e-2,1e-2)           # Rotation
+        #pdflow!(x, Δx, y, Δy, v_known, theta_known, params.dual_flow, 1e-2)                # Adhoc
+        #@. oldx = x
+
+        if params.prox_predict
+            ∇₂!(Δy, x)
+            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
+            #@. cc = y + 1000000*σ̃*Δy 
+            #@. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α)) + (1 - 1/(1 + ρ̃*σ̃))*cc
+            proj_norm₂₁ball!(y, α) 
+        end
+
+
+        ############
+        # PDPS step
+        ############
+
+        ∇₂ᵀ!(Δx, y)                    # primal step:
+        @. x̄ = x                       # | save old x for over-relax
+        petgrad!(r∇, x, b, c, S)          # | Calculate gradient of fidelity term
+
+        @. x = x-(τ*λ)*r∇-τ*Δx         # |
+        proj_nonneg!(x)                # | non-negativity constaint prox
+        @. x̄ = (1+ω)*x - ω*x̄           # over-relax: x̄ = 2x-x_old
+        ∇₂!(Δy, x̄)                     # dual step:
+        @. y = y + σ*Δy                # |
+        proj_norm₂₁ball!(y, α)         # |  prox
+
+        ##########################################
+        # Compute for the local Lipschitz constant
+        ##########################################
+        # petgrad!(petgradx, x, b, c, S)
+        # petgrad!(oldpetgradx, oldx, b, c, S)
+        # if norm₂(x-oldx)>1e-12
+        #    L = max(0.9*norm₂(petgradx - oldpetgradx)/norm₂(x-oldx),L)
+        # end   
+       
+        ################################
+        # Give function value if needed
+        ################################
+        
+        v = verbose() do            
+            ∇₂!(Δy, x)
+            value = λ*petvalue(x, b, c) + params.α*norm₂₁(Δy)
+            value, x, [NaN, NaN], nothing
+        end 
+        
+        v
+    end
+
+    return x, y, v
+end
+
+end # Module
+
+

mercurial