Update tests
SCiarella committed Feb 26, 2025
1 parent c9cbc30 commit 0f3488f
Showing 4 changed files with 59 additions and 5 deletions.
45 changes: 44 additions & 1 deletion README.md
@@ -24,7 +24,50 @@ Pkg.add(url="git@github.com:DEEPDIP-project/AttentionLayer.jl.git")

## Usage

Look in `test/` for examples on how to use the package.
You are probably interested in the `attentioncnn` model, a built-in CNN that uses the attention mechanism.
Here is an example of how to use it:

* First, define the parameters of the model:

```julia
using AttentionLayer
using Random  # provides the Xoshiro generator

T = Float32 # the type of the data
N = 16 # size of the input
D = 2 # number of channels
rng = Xoshiro(123) # random number generator
r = [2, 2] # radii of the attention mechanism
c = [4, 2] # number of features of the intermediate layers
σ = [tanh, identity] # activation functions
b = [true, false] # use bias
emb_sizes = [8, 8] # size of the embeddings
patch_sizes = [8, 5] # size of the patches in which the attention mechanism is applied
n_heads = [2, 2] # number of heads of the attention mechanism
use_attention = [true, true] # use the attention at this layer
sum_attention = [false, false] # use attention in sum mode instead of concat mode (BUG)
```

* Then, construct the model:

```julia
closure, θ, st = attentioncnn(
T = T,
N = N,
D = D,
data_ch = D,
radii = r,
channels = c,
activations = σ,
use_bias = b,
use_attention = use_attention,
emb_sizes = emb_sizes,
patch_sizes = patch_sizes,
n_heads = n_heads,
sum_attention = sum_attention,
rng = rng,
use_cuda = false,
)
```
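
Once constructed, the closure can be applied like any Lux layer. The snippet below is a minimal sketch, assuming the standard Lux call convention `closure(x, θ, st)` and an input layout of `(N, N, D, batch)`; the input shape is an assumption, not something this commit specifies.

```julia
# Minimal usage sketch (assumed input layout: width × height × channels × batch)
x = rand(T, N, N, D, 1)    # one random sample with D channels
y, _ = closure(x, θ, st)   # Lux-style forward pass returns (output, new_state)
```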

Look in `test/` for more examples of how to use the package.

## How to Cite

13 changes: 9 additions & 4 deletions src/layer.jl
@@ -70,10 +70,15 @@ end
function Lux.parameterlength(
(; N, d, n_heads, dh, emb_size, patch_size, n_patches)::attention,
)
3 * n_heads * dh * (emb_size + 1) +
patch_size * patch_size * d * emb_size +
emb_size +
N * N * d * n_patches * n_heads * dh
size_wQ = n_heads * dh * (emb_size + 1)
size_wK = n_heads * dh * (emb_size + 1)
size_wV = n_heads * dh * (emb_size + 1)
size_Ew = emb_size * patch_size * patch_size * d
size_Eb = emb_size
size_U = N * N * d * n_patches * n_heads * dh

total_size = size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U
return total_size
end
Lux.statelength(::attention) = 9

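As a sanity check on the arithmetic above, here is a hypothetical worked example in Julia. The values mirror the README example; `dh` and `n_patches` are assumptions (the per-head dimension and patch count are not spelled out in this diff).

```julia
# Hypothetical parameter-count check (dh and n_patches are assumed conventions)
N, d, emb_size, patch_size, n_heads = 16, 2, 8, 8, 2
dh = emb_size ÷ n_heads             # assumed per-head dimension
n_patches = (N ÷ patch_size)^2      # assumed number of patches

size_wQKV = 3 * n_heads * dh * (emb_size + 1)                  # Q, K, V projections with bias
size_E    = emb_size * patch_size * patch_size * d + emb_size  # patch-embedding weight and bias
size_U    = N * N * d * n_patches * n_heads * dh               # output projection U

total_size = size_wQKV + size_E + size_U
```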
3 changes: 3 additions & 0 deletions test/test-cnn.jl
@@ -60,4 +60,7 @@ using Zygote: Zygote
grad = Zygote.gradient(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
@test !isnothing(grad) # Ensure gradients were successfully computed

y, back = Zygote.pullback(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
@test y == sum(abs2, closure(input_tensor, θ, st)[1])

end
3 changes: 3 additions & 0 deletions test/test-layer.jl
@@ -69,4 +69,7 @@ using Zygote: Zygote
grad = Zygote.gradient(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
@test !isnothing(grad) # Ensure gradients were successfully computed

y, back = Zygote.pullback(θ -> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
@test y == sum(abs2, closure(input_tensor, θ, st)[1])

end
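
The `back` closure returned by `Zygote.pullback` is not exercised by these tests. If one wanted to use it, seeding it with `one(y)` reproduces the gradient; a small sketch (not part of the commit):

```julia
# Sketch only: back(one(y)) returns a tuple of cotangents, here the gradient w.r.t. θ
grad_from_pullback = back(one(y))[1]
@test grad_from_pullback !== nothing
```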
