-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
125 lines (106 loc) · 3.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import datetime
from dataset.dataset_manager import DatasetManager
from utils.monitoring import Monitor
def make_folder_name() -> str:
    """Return the current local time formatted for use as a folder name.

    The format is ``YYMMDD_HHMMSS/`` (two-digit year) and includes a trailing
    slash, so the result can be appended directly to a parent directory path.

    Returns:
        str: Timestamp string such as ``"240131_235959/"``.
    """
    return datetime.datetime.now().strftime("%y%m%d_%H%M%S/")
def get_model_class(model: str):
    """Return the model class registered under the identifier ``model``.

    The model modules are imported inside the matching branch, so only the
    selected architecture's module is loaded.

    Args:
        model: Model identifier, either ``"dcgan"`` or ``"can"``.

    Returns:
        type: ``models.dcgan.DCGAN`` for ``"dcgan"``, ``models.can.CAN`` for
        ``"can"``.

    Raises:
        NotImplementedError: If ``model`` is not one of the supported
            identifiers. The message names the rejected value and the
            supported ones.
    """
    if model == "dcgan":
        from models.dcgan import DCGAN as model_class
    elif model == "can":
        from models.can import CAN as model_class
    else:
        raise NotImplementedError(
            "Unsupported model {!r}; expected 'dcgan' or 'can'".format(model)
        )
    return model_class
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def train(args):
    """Train the selected model on the dataset described by ``args``.

    Each epoch iterates over the dataset loader once. For every batch it
    generates fake images, performs one generator update and then one
    discriminator update, and records both losses. At the end of the epoch
    it prints the mean of each loss. It optionally saves a checkpoint every
    ``args.save_interval`` epochs.

    Args:
        args: Parsed command-line namespace. Reads ``save``, ``dataset_path``,
            ``batch_size``, ``model``, ``load_from``, ``load_only_generator``,
            ``monitoring``, ``monitoring_interval``, ``epoch`` and
            ``save_interval``.
    """
    # Checkpoints of one run go into a directory named after the start time.
    path_to_save = "checkpoints/" + make_folder_name() if args.save else None
    train_dataset_loader = DatasetManager(
        dataset_path=args.dataset_path, batch_size=args.batch_size
    ).get_dataset_loader()
    model_class = get_model_class(args.model)
    # NOTE(review): ``args.device`` is parsed on the command line but never
    # forwarded to the model; confirm whether the model picks its own device.
    model = model_class(batch_size=args.batch_size)
    if args.load_from is not None:
        model.load_model(args.load_from, args.load_only_generator)
    monitor = Monitor(args.batch_size) if args.monitoring else None
    status_str = "[{}/{} episode] generator loss : {} | discriminator_loss : {}"
    for ep in range(1, args.epoch + 1):
        generator_losses = []
        discriminator_losses = []
        # The second element of each loader item is ignored.
        for images, _ in train_dataset_loader:
            fake_images = model.generate_fake_images()
            if monitor is not None:
                monitor.monitor_images(fake_images, args.monitoring_interval)
            generator_loss = model.train_generator(fake_images)
            generator_losses.append(generator_loss.item())
            # The same fake batch is reused for the discriminator update.
            discriminator_loss = model.train_discriminator(images, fake_images)
            discriminator_losses.append(discriminator_loss.item())
        # _mean avoids a ZeroDivisionError when the loader yielded no batches.
        print(
            status_str.format(
                ep,
                args.epoch,
                _mean(generator_losses),
                _mean(discriminator_losses),
            )
        )
        if args.save and ep % args.save_interval == 0:
            model.save_model(path_to_save, "checkpoint_{}.pt".format(ep))
def test(args):
    """Evaluate a trained model (placeholder).

    Testing is not implemented yet, so calling this does nothing.

    Args:
        args: Parsed command-line namespace (currently unused).

    Returns:
        None
    """
    return None
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for training or testing a model.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, data,
        checkpointing, training, testing, monitoring and device options.
    """
    parser = argparse.ArgumentParser(description="Creative Adversarial Network")
    # model -- restricted to the identifiers get_model_class() accepts, so an
    # invalid value fails at parse time with a clear usage message.
    parser.add_argument(
        "--model",
        default="dcgan",
        type=str,
        choices=["dcgan", "can"],
        help="Model to train or test (ex. 'dcgan' or 'can')",
    )
    # data
    parser.add_argument(
        "--dataset-path", default="data/wikiart/", type=str, help="Wikiart dataset path"
    )
    # checkpointing
    parser.add_argument("--save", action="store_true", help="Whether to save the model")
    parser.add_argument(
        "--save-interval", type=int, default=20, help="Model save interval (in epochs)"
    )
    parser.add_argument("--load-from", type=str, help="Path to load the model")
    parser.add_argument(
        "--load-only-generator",
        action="store_true",
        help="Whether to load only the generator",
    )
    # train
    parser.add_argument("--epoch", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size")
    # test
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    # monitoring
    parser.add_argument(
        "--monitoring",
        action="store_true",
        help="Whether to visualize the generated images by generator",
    )
    parser.add_argument(
        "--monitoring-interval",
        type=int,
        default=1,
        help="Monitoring interval in seconds",
    )
    # gpu
    # NOTE(review): train() does not currently read args.device; confirm
    # whether the model classes select a device themselves.
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to use (ex. cpu, mps, cuda:0)"
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    if args.test:
        test(args)
    else:
        train(args)