
the main architecture is MLP #3

Open
WangJYao opened this issue Jul 4, 2023 · 1 comment
WangJYao commented Jul 4, 2023

Hi,
Thanks for open-sourcing this repository!
If the main architecture is an MLP, how should I set the parameters of the HyperNet? Or what should I pay attention to?

JJGO (Owner) commented Jul 7, 2023

Here is a minimal working example, similar to the convnet one in the documentation:

import torch
from torch import nn
import hyperlight as hl

class MLP(nn.Sequential):

    def __init__(self, in_features: int, out_features: int, hidden: list[int]):
        super().__init__()

        # Alternate Linear + LeakyReLU blocks through the hidden sizes
        for n_in, n_out in zip([in_features] + hidden, hidden):
            self.append(nn.Linear(n_in, n_out))
            self.append(nn.LeakyReLU())
        # Final projection to the output size, with no activation after it
        self.append(nn.Linear(hidden[-1], out_features))

class HyperMLP(nn.Module):

    def __init__(self):
        super().__init__()
        mainnet = MLP(100, 10, [32, 64])
        # Collect all Linear layers whose weights the hypernet will predict
        modules = hl.find_modules_of_type(mainnet, [nn.Linear])

        # Strip those layers' parameters so they are supplied externally
        self.mainnet = hl.hypernetize(mainnet, modules=modules)
        parameter_shapes = self.mainnet.external_shapes()
        # The hypernet maps a 10-dim input h to all externalized parameters
        self.hypernet = hl.HyperNet(
            input_shapes={'h': (10,)},
            output_shapes=parameter_shapes,
            hidden_sizes=[16, 64, 128],
        )

    def forward(self, main_input, hyper_input):
        # Predict the main network's parameters from the hyper input
        parameters = self.hypernet(h=hyper_input)

        # Attach the predicted parameters while running the main network
        with self.mainnet.using_externals(parameters):
            prediction = self.mainnet(main_input)

        return prediction

x = torch.randn(7, 100)  # a batch of 7 main-network inputs
h = torch.randn(10)      # hyper input that conditions the main weights
model = HyperMLP()
print(model(x, h).shape)
# torch.Size([7, 10])
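
As a quick sanity check, here is a sketch of counting how many values the hypernet must emit (this assumes external_shapes() returns a mapping from parameter names to shapes, as its use for output_shapes above suggests):

# Total number of parameter values the hypernet has to predict
n_external = sum(
    torch.Size(shape).numel()  # element count of each external weight/bias
    for shape in model.mainnet.external_shapes().values()
)
print(n_external)
# 5994  (3232 + 2112 + 650 across the three Linear layers)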

The main consideration is that when hypernetizing an MLP you can end up with a very large number of hypernetwork parameters: unlike convolutions, Linear layers have no weight reuse, so the hypernet must predict every weight individually. In the example above it has to emit all 5,994 external values, so its final 128-unit layer alone holds 128 × 5,994 + 5,994 ≈ 773k weights. You might therefore want to hypernetize only some of the layers.
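
A minimal sketch of that, assuming the modules= argument accepts any subset of the main network's layers (it is the same argument used above); here only the final nn.Linear is hypernetized and the hidden layers keep ordinary trainable weights:

mainnet = MLP(100, 10, [32, 64])
# MLP is an nn.Sequential, so its last Linear layer is mainnet[-1]
partial = hl.hypernetize(mainnet, modules=[mainnet[-1]])
print(partial.external_shapes())
# only the last layer's weight and bias shapes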
