##################
# Our main module
##################

__precompile__()

module PredictPDPS

########################
# Load external modules
########################

using ImageTools.ImFilter: gaussian
using AlgTools.Util

#####################
# Load local modules
#####################

include("OpticalFlow.jl")
include("Radon.jl")
include("ImGenerate.jl")
include("Run.jl")
include("AlgorithmProximal.jl")
include("AlgorithmBothMulti.jl")
include("AlgorithmFB.jl")
include("AlgorithmFBDual.jl")
include("AlgorithmNew.jl")
include("Stats.jl")
#include("PlotResults.jl")
include("PET/PET.jl")


import .AlgorithmBothMulti,
       .AlgorithmFB,
       .AlgorithmFBDual,
       .AlgorithmProximal,
       .AlgorithmNew

using .ImGenerate
using .OpticalFlow
using .Stats
#using .PlotResults
using .PET
using .Run

##############
# Our exports
##############

export run_experiments,
       batchrun_article,
       demo_known1, demo_known2, demo_known3,
       demo_unknown1,demo_unknown2,demo_unknown3,
       batchrun_denoising,
       batchrun_predictors,
       demo_denoising1, demo_denoising2, demo_denoising3,
       demo_denoising4, demo_denoising5, demo_denoising6, demo_denoising7,
       demo_petS1, demo_petS2, demo_petS3,
       demo_petS4, demo_petS5, demo_petS6, demo_petS7,
       demo_petB1, demo_petB2, demo_petB3,
       demo_petB4, demo_petB5, demo_petB6, demo_petB7,
       batchrun_shepplogan, batchrun_brainphantom, batchrun_pet,
       calculate_statistics
       #plot_denoising, plot_pet,
       
###################################
# Parameterisation and experiments
###################################

const default_save_prefix="img/"

# Baseline parameters shared by every experiment in this module. The
# experiment-specific parameter sets below are built from this NamedTuple with
# `⬿` (from AlgTools.Util; presumably a right-biased merge — confirm there).
# Field order is significant for the NamedTuple type, so keep it stable.
const default_params = (
    ρ = 0,
    verbose_iter = 100,
    maxiter = 10000,
    save_results = true,
    save_images = true,
    # Iteration numbers listed in ascending order; a Set, so order is irrelevant.
    save_images_iters = Set([1, 2, 3, 5,
                             10, 25, 30, 50,
                             100, 250, 300, 500,
                             1000, 2000, 2500, 3000, 4000, 5000,
                             6000, 7000, 7500, 8000, 8700, 9000, 10000]),
    pixelwise_displacement = false,
    dual_flow = true, # For AlgorithmProximal from 2019 paper
    handle_interrupt = true,
    init = :zero,
    plot_movement = false,
    # Same value as `Set(0)`: a Set{Int} holding only 0. NOTE(review): presumably
    # means "no stable interval" if iterations are counted from 1 — confirm.
    stable_interval = Set([0]),
)

# Test images shared by the experiment tuples below, both generated with the
# dimension argument (200, 300). `square` comes from `imgen_square` and
# `lighthouse` from `imgen_shake` applied to the "lighthouse" image; see
# ImGenerate.jl for what each generator actually produces.
const square = imgen_square((200, 300))
const lighthouse = imgen_shake("lighthouse", (200, 300))

# Parameter set for the known-displacement experiments in
# `experiments_pdps_known` and `experiments_fb_known`: `default_params`
# combined via `⬿` (AlgTools.Util) with noise, shake, regularisation and
# step-length settings. Individual experiments add further fields such as
# `phantom_ρ` on top of this.
const p_known₀ = default_params ⬿ (
    noise_level = 0.5,
    shake_noise_level = 0.05,
    shake = 2,
    α = 0.15,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    δ = 0.9,
    σ₀ = 1,
    τ₀ = 0.01,
)

# Experiments for 2019 paper

# Parameter set for the unknown-displacement experiments in
# `experiments_pdps_unknown_multi`: `default_params` combined via `⬿`
# (AlgTools.Util) with the noise, shake and step settings plus the extra
# fields used when the displacement is not known.
const p_unknown₀ = default_params ⬿ (
    noise_level = 0.3,
    shake_noise_level = 0.05,
    shake = 2,
    α = 0.2,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    σ₀ = 1,
    δ = 0.9,
    λ = 1,
    # NOTE(review): the 300*200 factor matches the (200, 300) test-image size;
    # presumably θ is scaled by pixel count — confirm before changing image size.
    θ = (300*200)*100^3,
    # `gaussian` from ImageTools.ImFilter; assumed (σ, kernel size) arguments — verify.
    kernel = gaussian((3, 3), (11, 11)),
    timestep = 0.5,
    displacement_count = 100,
    τ₀ = 0.01,
)

# Known-displacement experiments run with `AlgorithmProximal` under
# `DisplacementConstant`. Entry order matters: `demo_known1`, `demo_known2` and
# `demo_known3` select entries 3, 1 and 2 respectively.
#
# Every entry uses `p_known₀`. `p_known₀_denoising` must not be referenced here:
# it is defined further down the module, and module top-level code runs in
# order, so referencing it would throw `UndefVarError` while loading.
const experiments_pdps_known = (
    # Lighthouse, phantom_ρ = 0
    Experiment(AlgorithmProximal, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (phantom_ρ = 0,)),
    # Lighthouse, phantom_ρ = 100
    Experiment(AlgorithmProximal, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (phantom_ρ = 100,)),
    # Square, phantom_ρ = 0
    Experiment(AlgorithmProximal, DisplacementConstant, square,
               p_known₀ ⬿ (phantom_ρ = 0,)),
)

# Unknown-displacement experiments run with `AlgorithmBothMulti` under
# `DisplacementConstant`, all based on `p_unknown₀`. Entry order matters:
# `demo_unknown1`, `demo_unknown2` and `demo_unknown3` select entries 3, 1 and 2.
const experiments_pdps_unknown_multi = (
    # Lighthouse, phantom_ρ = 0
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
    # Lighthouse, phantom_ρ = 100
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 100,)),
    # Square, phantom_ρ = 0
    Experiment(AlgorithmBothMulti, DisplacementConstant, square,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
)

# Known-displacement experiment run with `AlgorithmFB` on the lighthouse image.
# It extends `p_known₀` with the FB-specific fields `τ̃₀` and
# `fb_inner_iterations`.
const experiments_fb_known = (
    Experiment(AlgorithmFB, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (τ̃₀=0.9, fb_inner_iterations = 10)),
)

# Lazy concatenation of every experiment tuple above, in the order listed.
# `batchrun_article` passes this to `run_experiments`.
const experiments_all = Iterators.flatten((
    experiments_pdps_known,
    experiments_pdps_unknown_multi,
    experiments_fb_known
))

# Image stabilisation experiments for 2024 paper. PET experiments are in PET/PET.jl

# Parameter set for the image-stabilisation (denoising) experiments in
# `denoising_experiments_pdps_known`: `default_params` combined via `⬿`
# (AlgTools.Util). It uses float-valued settings and replaces the default
# `stable_interval` with iterations 2500–5000 and 8700–10000.
const p_known₀_denoising = default_params ⬿ (
    noise_level = 0.5,
    shake_noise_level = 0.025,
    shake = 2.0,
    α = 0.25,
    ρ̃₀ = 1.0,
    σ̃₀ = 1.0,
    δ = 0.9,
    σ₀ = 1.0,
    τ₀ = 0.01,
    #stable_interval = Set(0),
    stable_interval = union(Set(2500:5000),Set(8700:10000)),
)

# Image-stabilisation experiments on the lighthouse image, all based on
# `p_known₀_denoising`. Entry order matters: `demo_denoisingK` runs entry K.
# Entries 1–4, 6 and 7 compare `predictor` choices for `AlgorithmNew`.
# Entry 5 runs `AlgorithmProximal` for comparison.
const denoising_experiments_pdps_known = (
    # 1: DualScaling predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=DualScaling(),)),
    # 2: Greedy predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=Greedy(),)),
    # 3: no predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=nothing,),),
    # 4: PrimalOnly predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=PrimalOnly(),)),
    # 5: AlgorithmProximal with phantom_ρ = 100
    Experiment(AlgorithmProximal, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (phantom_ρ = 100,)),
    # 6: Rotation predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=Rotation(),)),
    # 7: ZeroDual predictor
    Experiment(AlgorithmNew, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (predictor=ZeroDual(),)),
)

# Lazy iterator over every denoising experiment; `batchrun_denoising` passes it
# to `run_experiments`. There is currently a single source tuple, wrapped in
# `Iterators.flatten` in the same shape as `experiments_all`.
const denoising_experiments_all = Iterators.flatten((
    denoising_experiments_pdps_known,
))

#######################
# Demos and batch runs
#######################

"""
    demo(experiment; kwargs...)

Run the single `experiment` through `run_experiments` with interactive
defaults:

- `save_results=false`, `save_images=false`
- `save_prefix=default_save_prefix`
- `visualise=true`, `fullscreen=true`
- `recalculate=true`, `verbose_iter=50`

Any keyword argument passed here overrides the matching default.
"""
function demo(experiment; kwargs...)
    interactive_settings = (save_results = false,
                            save_images = false,
                            save_prefix = default_save_prefix,
                            visualise = true,
                            recalculate = true,
                            verbose_iter = 50,
                            fullscreen = true)
    # Caller keywords are splatted last: the rightmost occurrence of a keyword wins.
    return run_experiments(; experiments = (experiment,),
                           interactive_settings...,
                           kwargs...)
end

# Interactive demos: each zero-argument function runs one experiment via `demo`
# (visualised, nothing saved). They are plain named functions rather than
# anonymous closures stored in non-const, untyped module globals.

# Known displacement (`experiments_pdps_known`).
demo_known1() = demo(experiments_pdps_known[3]) # square, phantom_ρ = 0
demo_known2() = demo(experiments_pdps_known[1]) # lighthouse, phantom_ρ = 0
demo_known3() = demo(experiments_pdps_known[2]) # lighthouse, phantom_ρ = 100

# Unknown displacement (`experiments_pdps_unknown_multi`), with movement plotting.
demo_unknown1() = demo(experiments_pdps_unknown_multi[3]; plot_movement=true)
demo_unknown2() = demo(experiments_pdps_unknown_multi[1]; plot_movement=true)
demo_unknown3() = demo(experiments_pdps_unknown_multi[2]; plot_movement=true)

# Image stabilisation (`denoising_experiments_pdps_known`), one per predictor.
demo_denoising1() = demo(denoising_experiments_pdps_known[1]) # Dual scaling
demo_denoising2() = demo(denoising_experiments_pdps_known[2]) # Greedy
demo_denoising3() = demo(denoising_experiments_pdps_known[3]) # No Prediction
demo_denoising4() = demo(denoising_experiments_pdps_known[4]) # Primal Only
demo_denoising5() = demo(denoising_experiments_pdps_known[5]) # Proximal (old)
demo_denoising6() = demo(denoising_experiments_pdps_known[6]) # Rotation
demo_denoising7() = demo(denoising_experiments_pdps_known[7]) # Zero dual

"""
    batchrun_article(; kwargs...)

Run every experiment in `experiments_all` through `run_experiments` in batch
mode:

- `save_results=true`, `save_images=true`
- `save_prefix=default_save_prefix`
- `visualise=false`, `recalculate=false`

Keyword arguments override these settings, e.g.
`batchrun_article(recalculate=true)`, matching `batchrun_denoising`.

For backward compatibility, overrides may also be passed positionally as
`Symbol => value` pairs, e.g. `batchrun_article(:recalculate => true)`. That
was the only form the earlier positional-varargs signature accepted.
"""
function batchrun_article(pair_overrides...; kwargs...)
    # Both sources are splatted after the defaults so they take precedence;
    # genuine keyword arguments come last and win over positional pairs.
    return run_experiments(; experiments=experiments_all,
                           save_prefix=default_save_prefix,
                           save_results=true,
                           save_images=true,
                           visualise=false,
                           recalculate=false,
                           pair_overrides...,
                           kwargs...)
end

"""
    batchrun_denoising(; kwargs...)

Run every experiment in `denoising_experiments_all` through `run_experiments`
in batch mode:

- `save_results=true`, `save_images=true`
- `save_prefix=default_save_prefix`
- `visualise=false`, `recalculate=false`

Any keyword argument passed here overrides the matching setting.
"""
function batchrun_denoising(; kwargs...)
    batch_settings = (save_prefix = default_save_prefix,
                      save_results = true,
                      save_images = true,
                      visualise = false,
                      recalculate = false)
    # Caller keywords are splatted last: the rightmost occurrence of a keyword wins.
    return run_experiments(; experiments = denoising_experiments_all,
                           batch_settings...,
                           kwargs...)
end


"""
    batchrun_predictors(; kwargs...)

Run `batchrun_denoising` and then `batchrun_pet`, forwarding the same keyword
arguments to both. Returns the value returned by `batchrun_pet`.
"""
batchrun_predictors(; kwargs...) =
    (batchrun_denoising(; kwargs...); batchrun_pet(; kwargs...))

#########################
# Plotting SSIM and PSNR
#########################

#function plot_denoising(kwargs...)
#    ssim_plot("lighthouse")
#    psnr_plot("lighthouse")
#    fv_plot("lighthouse")
#end

end # Module
