Add a guide for writing a simple recurrent network #511
Comments
Staring some more into the Axon code I noticed that |
Does this example help? https://github.com/elixir-nx/axon/blob/main/examples/generative/text_generator.exs |
@polvalente thanks for the quick reply! Right now I figured out maybe I should avoid the Axon's In general, I am not sure if Axon intends to provide a re-usable building block (like |
I can't speak too much about the intention behind the design, but since we don't have an explicit Some things I concluded (pending any corrections by @seanmor5) that might help you:
So if you have an input which is of the shape Also you want to receive
{cell, hidden} = carry
{wii, wif, wig, wio} = input_kernel
{whi, whf, whg, who} = hidden_kernel
{bi, bf, bg, bo} = bias
i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0))
f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0))
g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0))
o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0))
new_c = f * cell + i * g
new_h = o * activation_fn.(new_c)
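Written as equations, the cell update above reads roughly as follows. This is only a restatement of the snippet under some assumptions: dense(x, W, b) is taken to mean x W + b, the gate function gate_fn is written as \sigma_g, the activation function activation_fn as \phi, and the subscripts t and t-1 are added here to mark the current step and the incoming carry.

```latex
\begin{aligned}
i_t &= \sigma_g\left(x_t W_{ii} + b_i + h_{t-1} W_{hi}\right) \\
f_t &= \sigma_g\left(x_t W_{if} + b_f + h_{t-1} W_{hf}\right) \\
g_t &= \phi\left(x_t W_{ig} + b_g + h_{t-1} W_{hg}\right) \\
o_t &= \sigma_g\left(x_t W_{io} + b_o + h_{t-1} W_{ho}\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \phi\left(c_t\right)
\end{aligned}
```

Here c_{t-1} and h_{t-1} correspond to cell and hidden in the carry, and c_t and h_t to new_c and new_h.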
|
Thanks for the clarifications. I am trying to go forward without using the I am working on an example livebook where I rewrite a PyTorch example in Axon. |
There is a new API
input = Axon.input("seq")
# pad token is 0
mask = Axon.mask(input, 0)
embed = Axon.embedding(input, ...)
{seq, state} = Axon.lstm(embed, 32, mask: mask)
|
Thanks! I would like to reiterate that an example of a custom RNN using all these features (unroll, masking, how to implement a "cell", whether we can call other layers from an RNN "cell") would be awesome to see in the Axon guides! |
Hi!
Axon beginner here.
I am struggling to figure out how to write a very simple RNN. Basically I want to rewrite this PyTorch example from a tutorial.
However, the Axon API makes it a bit convoluted to create networks that "scan" or "unroll" the input. After some digging I realized that I need to create something similar to lstm_cell and lstm, but these APIs are not well documented (what do the dynamic_unroll arguments mean?). I am also not sure how to handle parameters in that case so that the training mechanism (Axon.Loop.trainer with standard optimizer and loss functions) can do its job.
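To make the "scan"/"unroll" idea concrete, here is a minimal sketch in plain Nx. It is not Axon's lstm, lstm_cell, or dynamic_unroll, and it does not show how Axon registers trainable parameters. It only illustrates the pattern: a cell function is applied once per time step while a hidden state (the carry) is threaded through. The module name SimpleRNN, the parameter names w_ih, w_hh and b, the constant weights, and the Nx version constraint are all assumptions made for this illustration.

```elixir
# Standalone script; only Nx is needed. The version constraint is an assumption.
Mix.install([{:nx, "~> 0.7"}])

defmodule SimpleRNN do
  # One Elman-style step: new_h = tanh(x_t . w_ih + h . w_hh + b).
  # Parameter names are hypothetical, not taken from Axon.
  def cell(x_t, h, %{w_ih: w_ih, w_hh: w_hh, b: b}) do
    x_t
    |> Nx.dot(w_ih)
    |> Nx.add(Nx.dot(h, w_hh))
    |> Nx.add(b)
    |> Nx.tanh()
  end

  # "Unrolling": apply the cell once per time step, passing the hidden
  # state along. `input` has shape {batch, time, features}.
  def unroll(input, params) do
    {batch, time, _features} = Nx.shape(input)
    {hidden_size, _} = Nx.shape(params.w_hh)
    h0 = Nx.broadcast(0.0, {batch, hidden_size})

    {outputs, last_h} =
      Enum.reduce(0..(time - 1), {[], h0}, fn t, {acc, h} ->
        # Take time step t and drop the time axis: {batch, features}
        x_t = input |> Nx.slice_along_axis(t, 1, axis: 1) |> Nx.squeeze(axes: [1])
        new_h = cell(x_t, h, params)
        {[new_h | acc], new_h}
      end)

    # Stack the per-step hidden states back into {batch, time, hidden}
    {outputs |> Enum.reverse() |> Nx.stack(axis: 1), last_h}
  end
end

# Fixed, non-random weights so the script is deterministic.
params = %{
  w_ih: Nx.broadcast(0.1, {4, 5}),
  w_hh: Nx.broadcast(0.1, {5, 5}),
  b: Nx.broadcast(0.0, {5})
}

# 2 sequences, 3 time steps, 4 features per step
input = Nx.iota({2, 3, 4}, type: :f32) |> Nx.divide(10)
{outputs, last_h} = SimpleRNN.unroll(input, params)

IO.inspect(Nx.shape(outputs), label: "outputs shape")
IO.inspect(Nx.shape(last_h), label: "last hidden shape")
```

The loop here runs eagerly in regular Elixir, one step at a time, which is easy to read but is only a teaching device. An LSTM version would carry a {cell, hidden} tuple instead of a single hidden tensor.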