-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathautoencoder.py
42 lines (36 loc) · 1.34 KB
/
autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# -*- coding: utf-8 -*-
"""
@project: lessr-master-force
@author: daijiuqian
@file: autoencoder.py
@ide: PyCharm
@time: 2022-02-28 09:22:15
"""
import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
class autoencoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder/decoder.

    The encoder is a stack of ``nn.Linear`` layers separated by in-place
    ReLU activations that maps ``input_dim`` features down to ``latent_dim``
    features. The decoder applies the same layer widths in reverse order to
    map the latent code back to ``input_dim`` features. Neither stack has an
    activation after its last linear layer.

    With the default arguments the network is 480 -> 128 -> 64 -> 32 ->
    64 -> 128 -> 480, and the ``state_dict`` keys are ``encoder.{0,2,4}.*``
    and ``decoder.{0,2,4}.*``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dims: Widths of the encoder's hidden layers, ordered from the
            input side to the latent side. The decoder uses them reversed.
            May be empty, in which case each half is a single linear layer.
        latent_dim: Size of the encoded representation.
    """

    def __init__(self, input_dim=480, hidden_dims=(128, 64), latent_dim=32):
        super().__init__()
        widths = (input_dim, *hidden_dims, latent_dim)
        # The encoder is built before the decoder so parameters are created
        # (and therefore randomly initialised) in a fixed order: encoder
        # layers first, then decoder layers.
        self.encoder = self._build_stack(widths)
        self.decoder = self._build_stack(widths[::-1])

    @staticmethod
    def _build_stack(widths):
        """Return Linear layers for consecutive ``widths`` with ReLU between."""
        modules = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU(True))
        # Drop the trailing ReLU: the last layer of each stack stays linear.
        return nn.Sequential(*modules[:-1])

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the latent code.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            A tuple ``(encode, decode)``: the latent representation (last
            dimension ``latent_dim``) and the reconstruction (last dimension
            ``input_dim``).
        """
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode