Skip to content

Commit

Permalink
weights_only
Browse files Browse the repository at this point in the history
  • Loading branch information
forestagostinelli committed Nov 23, 2024
1 parent 7eac594 commit 4288b1f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion deepxube/environments/n_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_v_nnet(self) -> HeurFnNNet:
return nnet

def get_q_nnet(self) -> HeurFnNNet:
nnet = NNet(self.num_tiles, self.num_tiles + 1, 5000, 1000, 4, self.num_actions, True, False, "V")
nnet = NNet(self.num_tiles, self.num_tiles + 1, 5000, 1000, 4, self.num_actions, True, False, "Q")
return nnet

def get_start_states(self, num_states: int) -> List[NPState]:
Expand Down
4 changes: 2 additions & 2 deletions deepxube/nnet/nnet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def get_device() -> Tuple[torch.device, List[int], bool]:
def load_nnet(model_file: str, nnet: nn.Module, device: Optional[torch.device] = None) -> nn.Module:
# get state dict
if device is None:
state_dict = torch.load(model_file)
state_dict = torch.load(model_file, weights_only=True)
else:
state_dict = torch.load(model_file, map_location=device)
state_dict = torch.load(model_file, map_location=device, weights_only=False)

# remove module prefix
new_state_dict = OrderedDict()
Expand Down

0 comments on commit 4288b1f

Please sign in to comment.