Add exclude_from_weight_decay to AdamW #16274

Conversation
```diff
@@ -173,7 +181,7 @@ def update_step(self, gradient, variable):
         alpha = (lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))

         # Apply step weight decay
-        if self.weight_decay != 0:
+        if self.weight_decay and self._do_use_weight_decay(variable):
```
Better than `if self.weight_decay != 0:` is `if self.weight_decay:`, because the user can set `weight_decay=None`. `None` is not a number (float) like `0`; it's an empty object.
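As a quick standalone illustration (not code from this PR): the truthiness check covers both the `0` and `None` cases, while `!= 0` lets `None` fall through into the decay branch.

```python
# Standalone sketch: compare the two guard conditions for possible
# weight_decay values. Illustrative only; not the optimizer code itself.
for weight_decay in (None, 0, 0.004):
    passes_neq = weight_decay != 0      # True for None -> decay branch would run with None
    passes_truthy = bool(weight_decay)  # False for None and 0, True for 0.004
    print(weight_decay, passes_neq, passes_truthy)
```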
Thanks for the PR!

No variable name should ever be used in the public API surface, because variable names are not stable. Instead, this configuration parameter should be an actual list of variables (e.g. `[layer.bias for layer in my_dense_layers]`), which will be checked by identity.

Such a configuration parameter cannot be serialized, and it is specific to the set of weights that the optimizer will train, which is not defined at optimizer creation time but only in `build()`. As such it cannot be a constructor argument. Instead, this should be configured via a method.

Let's add it to `build()`:

```python
def build(self, var_list, exclude_from_weight_decay=None)
```

Then you can do:

```python
model = ...
optimizer = AdamW()
optimizer.build(model.trainable_weights, exclude_from_weight_decay=get_bias_terms(model))
```

We could also have a separate method:

```python
optimizer.exclude_from_weight_decay(var_list)
```

Does that make sense?
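A minimal sketch of how such a helper could be written; `get_bias_terms` is a hypothetical function (not part of Keras) that collects bias variables so they can later be matched by identity:

```python
# Hypothetical helper: gather the bias variables of a model so they can be
# excluded from weight decay. The variables are matched by identity later.
def get_bias_terms(model):
    return [
        layer.bias
        for layer in model.layers
        if getattr(layer, "bias", None) is not None
    ]

# Usage under the proposed API (build() accepting exclude_from_weight_decay):
# optimizer = AdamW()
# optimizer.build(model.trainable_weights,
#                 exclude_from_weight_decay=get_bias_terms(model))
```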
Yes, I'll make the changes. Variable names can be changed dynamically, but variables keep stable identities for the lifetime of the Model object. Thanks.
Thanks for the update! Please add a correctness test for the change.
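For reference, a correctness test along these lines could look like the sketch below. It assumes the `AdamW` class from this PR is in scope and that `build()` accepts `exclude_from_weight_decay`; it is not the test that actually landed. With zero gradients, only the weight-decay step moves the variables.

```python
import tensorflow as tf

def test_exclude_from_weight_decay():
    # Zero gradients so the Adam update itself is a no-op; any change to the
    # variables comes from the weight-decay step alone.
    grads = tf.zeros(())
    decayed, excluded = tf.Variable(2.0), tf.Variable(2.0)

    opt = AdamW(learning_rate=0.001, weight_decay=0.004)
    opt.build([decayed, excluded], exclude_from_weight_decay=[excluded])
    opt.apply_gradients(zip([grads, grads], [decayed, excluded]))

    assert excluded.numpy() == 2.0  # excluded variable is untouched
    assert decayed.numpy() < 2.0    # non-excluded variable was decayed
```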
```diff
@@ -136,6 +136,7 @@ def build(self, var_list):
         if hasattr(self, '_built') and self._built:
             return
         self._built = True
+        self._exclude_from_weight_decay = exclude_from_weight_decay
```
`exclude_from_weight_decay or []` (need to support the default `None` case)
For the future, you can define it as `exclude_from_weight_decay: Optional[List[tf.Variable]] = None`. And I added the `None` case handling.
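Putting the two suggestions together, the signature could look like the following sketch (only the relevant lines are shown; this is illustrative, not necessarily the exact code that landed):

```python
from typing import List, Optional

import tensorflow as tf

# Method sketch for the optimizer class; the rest of build() is elided.
def build(
    self,
    var_list,
    exclude_from_weight_decay: Optional[List[tf.Variable]] = None,
):
    ...
    # Fall back to an empty list to support the default None case.
    self._exclude_from_weight_decay = exclude_from_weight_decay or []
```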
```diff
@@ -215,6 +216,8 @@ def get_config(self):
         })
         return config

+    def exclude_from_weight_decay(self, var_list):
+        self._exclude_from_weight_decay = var_list
```
Need to check that the optimizer has not already been built, and raise an error if it has; otherwise calling this might not have any effect.

Also need to cast to list.
OK, you're right. I added the error. Thanks.
Thanks for the update!
```python
if var_list == None:
    self._exclude_from_weight_decay = []
else:
    self._exclude_from_weight_decay = var_list
```
You can replace these 4 lines with `self._exclude_from_weight_decay = var_list or []`.
Ok.
```python
        else:
            self._exclude_from_weight_decay = var_list
    else:
        raise ValueError('This optimizer has not yet been built. '
```
The check should work the other way around: the user needs to call `exclude_from_weight_decay` before the optimizer has been built.
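A minimal sketch of the corrected method under the points discussed above (raise if the optimizer is already built, cast the argument to a list, and fall back to an empty list for `None`); attribute names follow this PR, and this is not necessarily the exact code that landed:

```python
# Method sketch for the optimizer class.
def exclude_from_weight_decay(self, var_list):
    if getattr(self, "_built", False):
        # Per the review discussion, this must be called before the
        # optimizer is built, so fail loudly otherwise.
        raise ValueError(
            "exclude_from_weight_decay() must be called before the "
            "optimizer is built."
        )
    self._exclude_from_weight_decay = list(var_list or [])
```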
Sorry.
LGTM, thanks!
```diff
@@ -173,7 +174,7 @@ def update_step(self, gradient, variable):
         alpha = (lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))

         # Apply step weight decay
-        if self.weight_decay != 0:
+        if self.weight_decay and variable not in self._exclude_from_weight_decay:
```
`self.weight_decay` cannot be used as a Pythonic bool. This raises the error:

```
OperatorNotAllowedInGraphError: Using a symbolic `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```

`self.weight_decay != 0` could be a solution.
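For context, here is a standalone illustration of the failure mode (not this PR's exact code path): calling `bool()` on a symbolic tensor while a graph is being traced is what triggers this kind of error.

```python
import tensorflow as tf

@tf.function(autograph=False)
def apply_decay(wd, variable):
    # With AutoGraph unable to rewrite the condition, Python calls bool() on
    # the symbolic tensor `wd` during tracing, which raises
    # OperatorNotAllowedInGraphError.
    if wd:
        variable.assign_sub(variable * wd)

# v = tf.Variable(2.0)
# apply_decay(tf.constant(0.004), v)  # raises OperatorNotAllowedInGraphError
```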
Also, please add a correctness test -- this sort of issue should be caught in unit tests.
Can you share the code you're using to print the variable values?
We should decide whether we want to go with equation A or B.

Equation A: `variable.assign_sub(variable * wd * lr)`

Equation B: `variable.assign_sub(variable * wd)`

What are the arguments for each? Apparently A is what PyTorch uses (and our current implementation). B is what the paper specifies and also what TF Addons uses. @markub3327 you seem to be saying that B is more intuitive.

This code was authored by @yarri-oss -- Douglas, what do you think?
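A quick numeric illustration of the difference (hypothetical values; this assumes Equation A scales the decay by the learning rate, as in PyTorch):

```python
# One decay step on a weight of 2.0 with lr=0.001 and weight_decay=0.004.
w, lr, wd = 2.0, 0.001, 0.004

after_a = w - w * wd * lr  # Equation A: decay coupled to the learning rate
after_b = w - w * wd       # Equation B: decoupled decay

print(after_a)  # ~1.999992 -- a tiny shrink per step
print(after_b)  # ~1.992    -- a much larger shrink for the same wd value
```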
@markub3327 thank you again for your feedback. As @fchollet mentioned, we've seen arguments for each of the above equations. A) seems to be in use with other frameworks, and works with the assumption of smaller `weight_decay` values. Are you experienced with the TF Addons AdamW implementation? Is that why you prefer the larger values of `weight_decay`? Thanks!
@yarri-oss Yes, I'm familiar with the original paper and with the implementation from TF Addons. Option A) is a slightly different thing. With the current implementation, during hyperparameter tuning I must find higher weight-decay values for less impact of weight decay on the trainable variables (I prefer smaller weight decay for smaller impact). Right now the impact of weight decay is more aggressive and quickly takes the variables to zero.

For example, @fchollet, I get it by changing the unit test to show me a result:

```python
def testWeightDecay(self):
    grads, var1, var2, var3 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0), tf.Variable(2.0)
    optimizer_1 = adamw_new.AdamW(learning_rate=0.001, weight_decay=1.0)
    optimizer_1.apply_gradients(zip([grads], [var1]))
    print(var1)
    optimizer_2 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_2.exclude_from_weight_decay([var2])
    optimizer_2.apply_gradients(zip([grads], [var2]))
    optimizer_3 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_3.build([var3], exclude_from_weight_decay=[var3])
    optimizer_3.apply_gradients(zip([grads], [var3]))
    self.assertEqual(var2, 0.0)
```

Scenario 2:

```python
def testWeightDecay(self):
    grads, var1, var2, var3 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0), tf.Variable(2.0)
    optimizer_1 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.001)
    optimizer_1.apply_gradients(zip([grads], [var1]))
    print(var1)
    optimizer_2 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_2.exclude_from_weight_decay([var2])
    optimizer_2.apply_gradients(zip([grads], [var2]))
    optimizer_3 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_3.build([var3], exclude_from_weight_decay=[var3])
    optimizer_3.apply_gradients(zip([grads], [var3]))
    self.assertEqual(var2, 0.0)
```

Thanks!
See also the author of the paper (Q2): loshchil/AdamW-and-SGDW#1 (comment)
@markub3327 This really depends on how we define `weight_decay`.
Okay. Are you saying the preferred option is B)? I can make the change in the code.
In light of the evidence so far I personally find option B, `variable.assign_sub(variable * wd)`, more convincing (especially given that the general behavior is more consistent with how we define weight decay or learning rate decay elsewhere in the API). Is this the consensus view?
Yea, following option B sounds good to me.
@fchollet
@yarri-oss does the change sound good?
Thanks for the PR!
The implementation looks correct to me.
Thanks for the PR, glad to see consensus on this topic.
Re-running tests. Previous run appeared to be flaky.
Hello @bhack, here are the changes for applying `exclude_from_weight_decay` to AdamW (#16201). Thanks.