
Add exclude_from_weight_decay to AdamW #16274

Merged
merged 10 commits into keras-team:master on Apr 5, 2022

Conversation

markub3327
Contributor

@markub3327 markub3327 commented Mar 20, 2022

Hello,

@bhack

here are the changes for applying exclude_from_weight_decay to AdamW (#16201).

Thanks.

@@ -173,7 +181,7 @@ def update_step(self, gradient, variable):
         alpha = (lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))

         # Apply step weight decay
-        if self.weight_decay != 0:
+        if self.weight_decay and self._do_use_weight_decay(variable):
Contributor Author

@markub3327 markub3327 Mar 20, 2022


Using if self.weight_decay: is better than if self.weight_decay != 0: because the user can set weight_decay=None. None is not a number (float) like 0; it's an empty object.
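A quick illustration of the difference in plain Python (not the Keras source):

weight_decay = None

# Old check: None != 0 evaluates to True, so the decay branch would run
# and then fail as soon as None is used in arithmetic.
print(weight_decay != 0)   # True

# New check: None is falsy, so the decay branch is skipped as intended.
print(bool(weight_decay))  # False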

@gbaned gbaned requested a review from fchollet March 21, 2022 09:46
@google-ml-butler google-ml-butler bot added the keras-team-review-pending (Pending review by a Keras team member) label Mar 21, 2022
Collaborator

@fchollet fchollet left a comment


Thanks for the PR!

No variable name should ever be used in the public API surface, because variable names are not stable.

Instead, this configuration parameter should be an actual list of variables (e.g. [layer.bias for layer in my_dense_layers]), which will be checked by identity.

Such a configuration parameter cannot be serialized and is specific to the set of weights that the optimizer will train, which is not defined at optimizer creation time, but only in build(). As such it cannot be a constructor argument. Instead, this should be configured via a method.

Let's add it to build():

def build(self, var_list, exclude_from_weight_decay=None)

Then you can do:

model = ...
optimizer = AdamW()
optimizer.build(model.trainable_weights, exclude_from_weight_decay=get_bias_terms(model))

We could also have a separate method:

optimizer.exclude_from_weight_decay(var_list)

Does that make sense?
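A minimal sketch of this proposed usage (assuming the experimental optimizer is exposed as tf.keras.optimizers.experimental.AdamW, as in TF 2.9; get_bias_terms is a hypothetical helper, not part of the Keras API):

import tensorflow as tf
from tensorflow import keras

def get_bias_terms(model):
    # Hypothetical helper: collect bias variables (checked by identity) from Dense layers.
    return [layer.bias for layer in model.layers
            if isinstance(layer, keras.layers.Dense) and layer.use_bias]

model = keras.Sequential([
    keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    keras.layers.Dense(1),
])

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=1e-3, weight_decay=4e-3)
optimizer.build(model.trainable_weights,
                exclude_from_weight_decay=get_bias_terms(model))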

@markub3327
Contributor Author

@fchollet

Yes, I'll make the changes. Variable names can be changed dynamically, but the variables themselves keep stable identities for the lifetime of the Model object.

Thanks.

@markub3327 markub3327 requested a review from fchollet March 22, 2022 05:39
Collaborator

@fchollet fchollet left a comment


Thanks for the update! Please add a correctness test for the change.

@@ -136,6 +136,7 @@ def build(self, var_list):
         if hasattr(self, '_built') and self._built:
             return
         self._built = True
+        self._exclude_from_weight_decay = exclude_from_weight_decay
Collaborator


exclude_from_weight_decay or [] (need to support the default None case)

Contributor Author

@markub3327 markub3327 Mar 22, 2022


For the future, you could define it as exclude_from_weight_decay: Optional[List[tf.Variable]] = None. I added handling for the None case.
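For reference, a sketch of how the updated build() might start once the None default is handled (based on the hunk above and the suggestion here; not necessarily the exact merged code):

def build(self, var_list, exclude_from_weight_decay=None):
    if hasattr(self, '_built') and self._built:
        return
    self._built = True
    # Support the default None case: fall back to an empty list so that
    # the exclusion check in update_step() always operates on a list.
    self._exclude_from_weight_decay = exclude_from_weight_decay or []
    # ... the rest of the original build() (slot variable creation) follows ...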

@@ -215,6 +216,8 @@ def get_config(self):
         })
         return config

+    def exclude_from_weight_decay(self, var_list):
+        self._exclude_from_weight_decay = var_list
Collaborator


Need to check that the optimizer has not already been built, and raise an error if it has; otherwise calling this might not have any effect.

Also need to cast to list.

Contributor Author


OK, you're right. I added the check that raises an error.

Thanks.

@fchollet fchollet removed the keras-team-review-pending (Pending review by a Keras team member) label Mar 22, 2022
@markub3327 markub3327 requested a review from fchollet March 22, 2022 19:48
@google-ml-butler google-ml-butler bot added the keras-team-review-pending (Pending review by a Keras team member) label Mar 22, 2022
Collaborator

@fchollet fchollet left a comment


Thanks for the update!

if var_list == None:
    self._exclude_from_weight_decay = []
else:
    self._exclude_from_weight_decay = var_list
Collaborator


You can replace these 4 lines with

self._exclude_from_weight_decay = var_list or []

Contributor Author


Ok.

    else:
        self._exclude_from_weight_decay = var_list
else:
    raise ValueError('This optimizer has not yet been built. '
Collaborator


The check should work the other way around: the user needs to call exclude_from_weight_decay before the optimizer has been built.
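A sketch of the setter with the check the right way around, plus the cast to list mentioned above (not necessarily the exact merged code):

def exclude_from_weight_decay(self, var_list):
    # Must be called before the optimizer is built (i.e. before build() or
    # the first apply_gradients() call); otherwise it would have no effect.
    if getattr(self, '_built', False):
        raise ValueError(
            'exclude_from_weight_decay() can only be called before '
            'the optimizer is built.')
    self._exclude_from_weight_decay = list(var_list or [])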

Contributor Author


Sorry.

@markub3327 markub3327 requested a review from fchollet March 24, 2022 20:13
Collaborator

@fchollet fchollet left a comment


LGTM, thanks!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Mar 25, 2022
@markub3327 markub3327 requested a review from fchollet March 25, 2022 05:48
@@ -173,7 +174,7 @@ def update_step(self, gradient, variable):
         alpha = (lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))

         # Apply step weight decay
-        if self.weight_decay != 0:
+        if self.weight_decay and variable not in self._exclude_from_weight_decay:
Contributor


self.weight_decay cannot be used as a Python bool. This raises the error:

OperatorNotAllowedInGraphError: Using a symbolic tf.Tensor as a Python bool is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature

self.weight_decay != 0 could be a solution.

Collaborator


Also, please add a correctness test -- this sort of issue should be caught in unit tests.

@fchollet
Collaborator

Can you share the code you're using to print the variable values?

@fchollet
Collaborator

fchollet commented Apr 1, 2022

We should decide whether we want to go with equation A or B:

Equation A: variable.assign_sub(variable * lr * wd)

Equation B: variable.assign_sub(variable * wd)

What are the arguments for each?

Apparently A is what PyTorch uses (and our current implementation). B is what the paper specifies and also what TF Addons uses. @markub3327 you seem to be saying that B is more intuitive.

This code was authored by @yarri-oss -- Douglas, what do you think?
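For concreteness, a small numeric comparison of the two candidate updates (illustrative values, not from the PR's tests):

lr, wd, variable = 0.001, 0.004, 2.0

# Equation A (PyTorch-style: the decay is scaled by the learning rate).
a = variable - variable * lr * wd   # 2.0 - 0.000008 = 1.999992

# Equation B (paper / TF Addons style: the decay is independent of the learning rate).
b = variable - variable * wd        # 2.0 - 0.008 = 1.992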

@yarri-oss

@markub3327 thank you again for your feedback. As @fchollet mentioned, we've seen arguments for each of the above equations. A) seems to be in use in other frameworks and works under the assumption of a smaller weight_decay, on the same order as the learning_rate. B) is used in TF Addons, which pre-dates this Keras AdamW implementation and, as you say, is more intuitive with larger (i.e., 10x ~ 100x the lr) values of weight_decay.

Are you experienced with the TF Addons AdamW implementation? Is that why you prefer the larger values of weight_decay?

Thanks!

@markub3327
Contributor Author

markub3327 commented Apr 1, 2022

@yarri-oss Yes, I'm familiar with the original paper and the implementation from TF Addons. Option A) is a slightly different thing. With the current implementation, during hyperparameter tuning I have to look for higher weight decay values to get less impact of weight decay on the trainable variables (I prefer a smaller weight decay for a smaller impact). Right now the impact of weight decay is more aggressive and quickly takes the variables to zero.

For example, weight_decay=1.0 still has a big impact on the variable because lr is 0.001. On the other hand, with too big an lr the impact of weight decay will be smaller, but the model can have problems finding the global minimum.

@fchollet I got this by changing the unit test to show me the result:
Scenario 1

def testWeightDecay(self):
    grads, var1, var2, var3 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0), tf.Variable(2.0)
    optimizer_1 = adamw_new.AdamW(learning_rate=0.001, weight_decay=1.0)
    optimizer_1.apply_gradients(zip([grads], [var1]))
    print(var1)

    optimizer_2 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_2.exclude_from_weight_decay([var2])
    optimizer_2.apply_gradients(zip([grads], [var2]))

    optimizer_3 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_3.build([var3], exclude_from_weight_decay=[var3])
    optimizer_3.apply_gradients(zip([grads], [var3]))

    self.assertEqual(var2, 0.0)

Scenario 2

def testWeightDecay(self):
    grads, var1, var2, var3 = tf.zeros(()), tf.Variable(2.0), tf.Variable(2.0), tf.Variable(2.0)
    optimizer_1 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.001)
    optimizer_1.apply_gradients(zip([grads], [var1]))
    print(var1)

    optimizer_2 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_2.exclude_from_weight_decay([var2])
    optimizer_2.apply_gradients(zip([grads], [var2]))

    optimizer_3 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_3.build([var3], exclude_from_weight_decay=[var3])
    optimizer_3.apply_gradients(zip([grads], [var3]))

    self.assertEqual(var2, 0.0)

Thanks!

@bhack
Contributor

bhack commented Apr 1, 2022

@bhack
Contributor

bhack commented Apr 1, 2022

See also the author of the paper (Q2) loshchil/AdamW-and-SGDW#1 (comment)

@chenmoneygithub
Contributor

@markub3327 This really depends on how we define decay. We are hitting this issue because the paper uses decay in the opposite way from learning rate decay. In this case, I would prefer sticking to the paper, where a larger weight decay means a larger impact on the model variables.

@markub3327
Contributor Author

@chenmoneygithub

OK. So the preferred option is B): variable.assign_sub(variable * wd)?

I can make the change in the code.

@fchollet
Collaborator

fchollet commented Apr 2, 2022

In light of the evidence so far I personally find option B variable.assign_sub(variable * wd) more convincing (especially given that the general behavior is more consistent with how we define weight decay or learning rate decay elsewhere in the API). Is this the consensus view?

@chenmoneygithub
Contributor

chenmoneygithub commented Apr 2, 2022 via email

@markub3327
Contributor Author

markub3327 commented Apr 3, 2022

@fchollet
All is done.

@keras-team keras-team deleted a comment from namas191297 Apr 3, 2022
@fchollet
Collaborator

fchollet commented Apr 3, 2022

@yarri-oss does the change sound good?

Contributor

@chenmoneygithub chenmoneygithub left a comment


Thanks for the PR!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Apr 3, 2022
@chenmoneygithub
Contributor

The implementation looks correct to me.

@yarri-oss

@yarri-oss does the change sound good?

Thanks for the PR, glad to see consensus on this topic.

@fchollet
Collaborator

fchollet commented Apr 5, 2022

Re-running tests. Previous run appeared to be flaky.

@copybara-service copybara-service bot merged commit 584cb2a into keras-team:master Apr 5, 2022
@edwardyehuang
Contributor

Any chance to include this in 2.9 rc0? This is a very important feature.
@fchollet @qlzh727
