Skip to content

Commit

Permalink
Remove redundant code
Browse files Browse the repository at this point in the history
  • Loading branch information
misaghsoltani committed Sep 3, 2024
1 parent 12c4dd8 commit 7e5b11d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 123 deletions.
2 changes: 0 additions & 2 deletions deepcubeai/training/train_env_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,6 @@ def main():
print(f"device: {device}, devices: {devices}, on_gpu: {on_gpu}")

nnet_model.to(device)
if on_gpu and False:
nnet_model = nn.DataParallel(nnet_model)

# load data
print("Loading data ...")
Expand Down
141 changes: 20 additions & 121 deletions deepcubeai/training/train_env_disc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import sys
import time
from argparse import ArgumentParser
from copy import deepcopy
from typing import Any, Dict, List, OrderedDict, Tuple
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -82,13 +81,9 @@ def parse_arguments(parser: ArgumentParser) -> Dict[str, Any]:
train_dir: str = f"{args_dict['save_dir']}/{args_dict['nnet_name']}/"
args_dict["train_dir"] = train_dir
args_dict["nnet_model_dir"] = f"{args_dict['train_dir']}/"
best_model_dir = f"{args_dict['nnet_model_dir']}/best_model"
if not os.path.exists(args_dict["nnet_model_dir"]):
os.makedirs(args_dict["nnet_model_dir"])

if not os.path.exists(best_model_dir) and False:
os.makedirs(best_model_dir)

if not os.path.exists(f"{args_dict['nnet_model_dir']}/pics"):
os.makedirs(f"{args_dict['nnet_model_dir']}/pics")

Expand Down Expand Up @@ -132,49 +127,21 @@ def load_nnet(nnet_dir: str, env: Environment) -> Tuple[nn.Module, nn.Module, nn
return encoder, decoder, env_model


def load_best_model_info(nnet_model_dir: str) -> Dict[str, Any]:
"""Loads the best model information.
Args:
nnet_model_dir (str): Directory of the neural network model.
Returns:
Dict[str, Any]: Best model information.
"""
best_model_info_file: str = f"{nnet_model_dir}/best_model/model_info.pkl"
best_model_info = None
if os.path.isfile(best_model_info_file):
with open(best_model_info_file, "rb") as f:
best_model_info: Dict[str, Any] = pickle.load(f)

return best_model_info


def load_train_state(
nnet_model_dir: str,
env: Environment) -> Tuple[nn.Module, nn.Module, nn.Module, int, Dict[str, Any]]:
def load_train_state(nnet_model_dir: str,
env: Environment) -> Tuple[nn.Module, nn.Module, nn.Module]:
"""Loads the training state.
Args:
nnet_model_dir (str): Directory of the neural network model.
env (Environment): The environment instance.
Returns:
Tuple[nn.Module, nn.Module, nn.Module, int, Dict[str, Any]]: Encoder, decoder, environment
model, iteration, and best model information.
Tuple[nn.Module, nn.Module, nn.Module]: Encoder, decoder, environment
model, and iteration information.
"""
best_model_itr_file: str = f"{nnet_model_dir}/best_model/train_itr.pkl"
best_model_info_file: str = f"{nnet_model_dir}/best_model/model_info.pkl"
itr_file: str = f"{nnet_model_dir}/train_itr.pkl"
best_model_info = None
if os.path.isfile(best_model_itr_file) and False:
with open(best_model_itr_file, "rb") as f:
itr: int = pickle.load(f)
nnet_model_dir: str = f"{nnet_model_dir}/best_model"
with open(best_model_info_file, "rb") as f:
best_model_info: Dict[str, Any] = pickle.load(f)

elif os.path.isfile(itr_file):

if os.path.isfile(itr_file):
with open(itr_file, "rb") as f:
itr: int = pickle.load(f) + 1

Expand All @@ -183,7 +150,7 @@ def load_train_state(

encoder, decoder, env_model = load_nnet(nnet_model_dir, env)

return encoder, decoder, env_model, itr, best_model_info
return encoder, decoder, env_model, itr


class Autoencoder(nn.Module):
Expand Down Expand Up @@ -401,8 +368,7 @@ def train_nnet(autoencoder: Autoencoder, env_nnet: nn.Module, nnet_dir: str,
state_episodes_train: List[np.ndarray], action_episodes_train: List[List[int]],
state_episodes_val: List[np.ndarray], action_episodes_val: List[List[int]],
device: torch.device, batch_size: int, num_itrs: int, start_itr: int, lr: float,
lr_d: float, env_coeff: float, only_env: bool, best_model_info: Dict[str,
Any]) -> None:
lr_d: float, env_coeff: float, only_env: bool) -> None:
"""Trains the neural network.
Args:
Expand All @@ -421,26 +387,8 @@ def train_nnet(autoencoder: Autoencoder, env_nnet: nn.Module, nnet_dir: str,
lr_d (float): Learning rate decay.
env_coeff (float): Environment coefficient.
only_env (bool): Whether to train only the environment model.
best_model_info (Dict[str, Any]): Information about the best model.
"""
# initialize
if best_model_info is None:
best_model_info = {
'l_env': float('inf'),
'l_env_val': float('inf'),
'l_recon': float('inf'),
'l_recon_val': float('inf'),
'loss': float('inf'),
'loss_val': float('inf'),
'itr': -1,
'max_itr': -1,
'lr': -1,
'env_coeff': -1
}
best_model_updated: bool = False
best_env_net_state_dict: OrderedDict[str, Tensor]
best_encoder_state_dict: OrderedDict[str, Tensor]
best_decoder_state_dict: OrderedDict[str, Tensor]
env_nnet.train()
autoencoder.train()
episode_lens_train: np.array = np.array(
Expand Down Expand Up @@ -592,59 +540,22 @@ def train_nnet(autoencoder: Autoencoder, env_nnet: nn.Module, nnet_dir: str,
plt.title("Reconstruction")
plt.close()

if True:
# Access the original modules if wrapped with DataParallel
env_nnet_module = env_nnet.module if isinstance(env_nnet,
nn.DataParallel) else env_nnet
encoder_module = autoencoder.module.encoder if isinstance(
autoencoder, nn.DataParallel) else autoencoder.encoder
decoder_module = autoencoder.module.decoder if isinstance(
autoencoder, nn.DataParallel) else autoencoder.decoder

torch.save(env_nnet_module.state_dict(), f"{nnet_dir}/env_state_dict.pt")
torch.save(encoder_module.state_dict(), f"{nnet_dir}/encoder_state_dict.pt")
torch.save(decoder_module.state_dict(), f"{nnet_dir}/decoder_state_dict.pt")
with open(f"{nnet_dir}/train_itr.pkl", "wb") as f:
pickle.dump(train_itr, f, protocol=-1)

if env_coeff == 0.5 and ((best_model_info['loss_val'] + best_model_info['loss']) / 2
> (loss_val.item() + loss.item()) / 2) and False:
# and ((best_model_info['l_env_val'] >= loss_env_val.item()) or (best_model_info['l_recon_val'] >= loss_recon_val.item()))
# and ((best_model_info['l_env'] >= loss_env.item()) or (best_model_info['l_recon'] >= loss_recon.item()))):
# if env_coeff == 0.5 and (best_model_info['l_env_val'] >= loss_env_val.item()
# and best_model_info['l_recon_val'] >= loss_recon_val.item()):
best_model_info['l_env'] = loss_env.item()
best_model_info['l_recon'] = loss_recon.item()
best_model_info['loss'] = loss.item()
best_model_info['l_env_val'] = loss_env_val.item()
best_model_info['l_recon_val'] = loss_recon_val.item()
best_model_info['loss_val'] = loss_val.item()
best_model_info['itr'] = train_itr
best_model_info['max_itr'] = num_itrs
best_model_info['lr'] = lr
best_model_info['env_coeff'] = env_coeff
best_env_net_state_dict = deepcopy(env_nnet_module.state_dict())
best_encoder_state_dict = deepcopy(encoder_module.state_dict())
best_decoder_state_dict = deepcopy(decoder_module.state_dict())
best_model_updated = True
env_nnet_module = env_nnet
encoder_module = autoencoder.encoder
decoder_module = autoencoder.decoder

torch.save(env_nnet_module.state_dict(), f"{nnet_dir}/env_state_dict.pt")
torch.save(encoder_module.state_dict(), f"{nnet_dir}/encoder_state_dict.pt")
torch.save(decoder_module.state_dict(), f"{nnet_dir}/decoder_state_dict.pt")
with open(f"{nnet_dir}/train_itr.pkl", "wb") as f:
pickle.dump(train_itr, f, protocol=-1)

print("")

env_nnet.train()
autoencoder.train()
start_time_all = time.time()

if best_model_updated:
torch.save(best_env_net_state_dict, f"{nnet_dir}/best_model/env_state_dict.pt")
torch.save(best_encoder_state_dict, f"{nnet_dir}/best_model/encoder_state_dict.pt")
torch.save(best_decoder_state_dict, f"{nnet_dir}/best_model/decoder_state_dict.pt")
with open(f"{nnet_dir}/best_model/train_itr.pkl", "wb") as f:
pickle.dump(best_model_info['itr'], f, protocol=-1)
with open(f"{nnet_dir}/best_model/model_info.pkl", "wb") as f:
pickle.dump(best_model_info, f, protocol=-1)
with open(f"{nnet_dir}/best_model/model_info.text", 'a') as file:
file.write(f"{best_model_info}\n")


def main():
"""Main function to run the training process."""
Expand Down Expand Up @@ -672,25 +583,13 @@ def main():
start_itr: int
encoder: nn.Module
decoder: nn.Module
best_model_info: Dict[str, Any]
encoder, decoder, env_nnet, start_itr, best_model_info = load_train_state(
args_dict['nnet_model_dir'], env)
encoder, decoder, env_nnet, start_itr, = load_train_state(args_dict['nnet_model_dir'], env)
env_nnet.to(device)
autoencoder: Autoencoder = Autoencoder(encoder, decoder)
autoencoder.to(device)

if on_gpu and len(devices) > 1:
env_nnet = nn.DataParallel(env_nnet)
autoencoder = nn.DataParallel(autoencoder)

print(f"Using {len(devices)} GPU(s): {devices}")

if best_model_info is not None:
itrs_num = 20000
args_dict['max_itrs'] = best_model_info['itr'] + itrs_num

best_model_info = load_best_model_info(args_dict['nnet_model_dir'])

print(f"Starting iteration: {start_itr}, Max iteration: {args_dict['max_itrs']}")

if args_dict['max_itrs'] <= start_itr:
Expand Down Expand Up @@ -723,7 +622,7 @@ def main():
train_nnet(autoencoder, env_nnet, args_dict['nnet_model_dir'], state_episodes_train,
action_episodes_train, state_episodes_val, action_episodes_val, device,
args_dict['batch_size'], args_dict['max_itrs'], start_itr, args_dict['lr'],
args_dict['lr_d'], args_dict['env_coeff'], args_dict['only_env'], best_model_info)
args_dict['lr_d'], args_dict['env_coeff'], args_dict['only_env'])

print("Testing after training")
test_model(encoder, env_nnet, state_episodes_val, action_episodes_val, device,
Expand Down

0 comments on commit 7e5b11d

Please sign in to comment.