From 0f3488f1f1fe2653fe68151550074343dff0e5e2 Mon Sep 17 00:00:00 2001
From: SCiarella
Date: Wed, 26 Feb 2025 16:27:45 +0100
Subject: [PATCH] Update tests

---
 README.md          | 45 ++++++++++++++++++++++++++++++++++++++++++++-
 src/layer.jl       | 13 +++++++++----
 test/test-cnn.jl   |  3 +++
 test/test-layer.jl |  3 +++
 4 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index ede173c..26d1979 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,50 @@ Pkg.add(url="git@github.com:DEEPDIP-project/AttentionLayer.jl.git")
 ```
 
 ## Usage
 
-Look in `test/` for examples on how to use the package.
+You are probably interested in the `attentioncnn` model, a built-in CNN that uses the attention mechanism.
+Here is an example of how to use it:
+
+* first, define the parameters of the model
+
+```julia
+    T = Float32                    # element type of the data
+    N = 16                         # size of the input
+    D = 2                          # number of channels
+    rng = Xoshiro(123)             # random number generator
+    r = [2, 2]                     # radii of the attention mechanism
+    c = [4, 2]                     # number of features of the intermediate layers
+    σ = [tanh, identity]           # activation functions
+    b = [true, false]              # use bias
+    emb_sizes = [8, 8]             # size of the embeddings
+    patch_sizes = [8, 5]           # size of the patches to which attention is applied
+    n_heads = [2, 2]               # number of heads of the attention mechanism
+    use_attention = [true, true]   # use attention at this layer
+    sum_attention = [false, false] # use attention in sum mode instead of concat mode (BUG)
+```
+
+* then you can call the model
+
+```julia
+    closure, θ, st = attentioncnn(
+        T = T,
+        N = N,
+        D = D,
+        data_ch = D,
+        radii = r,
+        channels = c,
+        activations = σ,
+        use_bias = b,
+        use_attention = use_attention,
+        emb_sizes = emb_sizes,
+        patch_sizes = patch_sizes,
+        n_heads = n_heads,
+        sum_attention = sum_attention,
+        rng = rng,
+        use_cuda = false,
+    )
+```
+
+Look in `test/` for more examples of how to use the package.
 
 ## How to Cite
diff --git a/src/layer.jl b/src/layer.jl
index b056c59..6d78790 100644
--- a/src/layer.jl
+++ b/src/layer.jl
@@ -70,10 +70,15 @@ end
 function Lux.parameterlength(
     (; N, d, n_heads, dh, emb_size, patch_size, n_patches)::attention,
 )
-    3 * n_heads * dh * (emb_size + 1) +
-    patch_size * patch_size * d * emb_size +
-    emb_size +
-    N * N * d * n_patches * n_heads * dh
+    size_wQ = n_heads * dh * (emb_size + 1)
+    size_wK = n_heads * dh * (emb_size + 1)
+    size_wV = n_heads * dh * (emb_size + 1)
+    size_Ew = emb_size * patch_size * patch_size * d
+    size_Eb = emb_size
+    size_U = N * N * d * n_patches * n_heads * dh
+
+    total_size = size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U
+    return total_size
 end
 
 Lux.statelength(::attention) = 9
diff --git a/test/test-cnn.jl b/test/test-cnn.jl
index d403504..b65e44d 100644
--- a/test/test-cnn.jl
+++ b/test/test-cnn.jl
@@ -60,4 +60,7 @@ using Zygote: Zygote
 
     grad = Zygote.gradient(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
     @test !isnothing(grad) # Ensure gradients were successfully computed
+    y, back = Zygote.pullback(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
+    @test y == sum(abs2, closure(input_tensor, θ, st)[1]) # pullback primal matches a direct forward evaluation
+
 end
diff --git a/test/test-layer.jl b/test/test-layer.jl
index 1d850ce..d4ec87b 100644
--- a/test/test-layer.jl
+++ b/test/test-layer.jl
@@ -69,4 +69,7 @@ using Zygote: Zygote
 
     grad = Zygote.gradient(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
     @test !isnothing(grad) # Ensure gradients were successfully computed
+    y, back = Zygote.pullback(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
+    @test y == sum(abs2, closure(input_tensor, θ, st)[1]) # pullback primal matches a direct forward evaluation
+
 end
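
For context, a minimal usage sketch (not part of the patch itself) showing how the closure returned by `attentioncnn` can be evaluated and differentiated with `Zygote.pullback`, mirroring the check added to the tests above. The `using AttentionLayer` import, the parameter values (copied from the README example), and the `(N, N, D, batch)` input layout are assumptions.

```julia
using Random, Zygote
using AttentionLayer  # assumed to export `attentioncnn`

T = Float32
N, D = 16, 2
rng = Xoshiro(123)

# Same configuration as in the README example above
closure, θ, st = attentioncnn(
    T = T, N = N, D = D, data_ch = D,
    radii = [2, 2], channels = [4, 2],
    activations = [tanh, identity], use_bias = [true, false],
    use_attention = [true, true], emb_sizes = [8, 8],
    patch_sizes = [8, 5], n_heads = [2, 2],
    sum_attention = [false, false], rng = rng, use_cuda = false,
)

# Assumed input layout: (N, N, D, batch)
x = rand(rng, T, N, N, D, 4)

# Forward value and gradient, mirroring the new pullback test
loss(θ) = sum(abs2, closure(x, θ, st)[1])
y, back = Zygote.pullback(loss, θ)
(grad,) = back(one(T))  # gradient of `loss` with respect to θ
```

Comparing the primal `y` with a direct forward evaluation, as the new tests do, is a cheap consistency check before calling `back` to obtain gradients.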