Commit

Add CNO CUDA wrapper (#15)

SCiarella authored Mar 5, 2025
1 parent 32a6e3f commit a774a1f

Showing 4 changed files with 61 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -19,7 +19,7 @@ The CNOs can then be used as custom Lux models and they are compatible with [clo

```julia
using Pkg
Pkg.add("git@github.com:DEEPDIP-project/ConvolutionalNeuralOperator.jl.git")
Pkg.add(url="git@github.com:DEEPDIP-project/ConvolutionalNeuralOperator.jl.git")
```

## Usage
2 changes: 1 addition & 1 deletion src/ConvolutionalNeuralOperators.jl
@@ -6,6 +6,6 @@ ArrayType = CUDA.functional() ? CUDA.CuArray : Array
include("utils.jl")
include("models.jl")

-export create_CNO, create_CNOdownsampler, create_CNOupsampler, create_CNOactivation
+export create_CNO, create_CNOdownsampler, create_CNOupsampler, create_CNOactivation, cno

end
14 changes: 14 additions & 0 deletions src/models.jl
@@ -428,3 +428,17 @@ end
function combined_mconv_activation_updown(y, k, mask, activation, updown)
updown(activation(apply_masked_convolution(y, k = k, mask = mask)))
end

function cno(; kwargs...)
    # Keyword wrapper around create_CNO: builds the model, sets up the Lux
    # parameters and state, and moves them to the requested device.
    use_cuda = get(kwargs, :use_cuda, false)
    rng = kwargs[:rng]
    if use_cuda
        dev = Lux.gpu_device()
    else
        dev = Lux.cpu_device()
    end
    # Strip the wrapper-only keywords before forwarding to create_CNO
    filtered_kwargs = Dict(k => v for (k, v) in kwargs if k != :use_cuda && k != :rng)
    model = create_CNO(; filtered_kwargs...)
    params, state = Lux.setup(rng, model)
    state = state |> dev
    params = ComponentArray(params) |> dev
    (model, params, state)
end
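
For context, a minimal usage sketch of the new wrapper, mirroring the call added in `test/test-training.jl` below; `T`, `N`, `D`, `cutoff`, `ch_`, `act`, `df`, `k_rad`, `bd`, and `rng` are assumed to be defined as in that test, and the input shape is illustrative:

```julia
using ConvolutionalNeuralOperators  # exports cno (see the module diff above)

# Build the model, its ComponentArray parameters, and its state on the CPU;
# use_cuda = true would instead place parameters and state on the GPU device.
model, θ, st = cno(
    T = T, N = N, D = D, cutoff = cutoff,
    ch_sizes = ch_, activations = act, down_factors = df,
    k_radii = k_rad, bottleneck_depths = bd,
    rng = rng, use_cuda = false,
)

u = rand(T, N, N, D, 4)   # illustrative input batch
y, _ = model(u, θ, st)    # forward pass through the Lux.Chain
```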
45 changes: 45 additions & 0 deletions test/test-training.jl
@@ -103,4 +103,49 @@ using Test # Importing the Test module for @test statements
loss_final = loss(optim_result.u, 128)
@test loss_final < loss_0 # Ensure loss decreases after optimization

# Now test the CUDA wrapper
model, θ, st = cno(
T = T,
N = N,
D = D,
cutoff = cutoff,
ch_sizes = ch_,
activations = act,
down_factors = df,
k_radii = k_rad,
bottleneck_depths = bd,
rng = rng,
use_cuda = false,
)

@info "There are $(length(θ)) parameters"

@test typeof(model) <: Lux.Chain
@test size(model(u, θ, st)[1]) == size(u)
function loss(θ, batch = 16)
y = rand(T, N, N, 1, batch)
y = cat(y, y, dims = 3)
yout = model(y, θ, st)[1]
return sum(abs2, (yout .- y))
end
loss_0 = loss(θ, 128)
@test isfinite(loss_0) # Ensure initial loss is a finite number
g = Zygote.gradient(θ -> loss(θ), θ)
@test !isnothing(g) # Ensure gradient is calculated successfully
function callback(p, l_train)
println("Training Loss: $(l_train)")
false
end
optf = Optimization.OptimizationFunction((p, _) -> loss(p), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, θ)
ClipAdam = OptimiserChain(Adam(1.0e-1), ClipGrad(1))
optim_result, optim_t, optim_mem, _ = @timed Optimization.solve(
optprob,
ClipAdam,
maxiters = 10,
callback = callback,
progress = true,
)
loss_final = loss(optim_result.u, 128)
@test loss_final < loss_0
end
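
The test above only exercises the CPU path (`use_cuda = false`). A hedged sketch of the GPU path, assuming a functional CUDA device and the same keyword values as in the test, could look like:

```julia
using CUDA

if CUDA.functional()
    # With use_cuda = true the wrapper selects Lux.gpu_device(), so the
    # returned parameters and state live on the GPU and inputs must too.
    model, θ_gpu, st_gpu = cno(
        T = T, N = N, D = D, cutoff = cutoff,
        ch_sizes = ch_, activations = act, down_factors = df,
        k_radii = k_rad, bottleneck_depths = bd,
        rng = rng, use_cuda = true,
    )
    u_gpu = CuArray(rand(T, N, N, D, 16))
    @test size(model(u_gpu, θ_gpu, st_gpu)[1]) == size(u_gpu)
end
```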
