numpy type
forestagostinelli committed Jun 20, 2024
1 parent ac3355e commit a2805e8
Showing 12 changed files with 42 additions and 42 deletions.
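
Background on the change: this commit likely tracks NumPy 2.0, which removed the deprecated np.float_ alias (np.float_ was simply an alias for np.float64), so annotations such as NDArray[np.float_] raise an AttributeError under NumPy >= 2.0. Every hunk below makes the same drop-in substitution. A minimal sketch of the annotation pattern, valid on both NumPy 1.x and 2.x (the function name is illustrative, not from the repo):

import numpy as np
from numpy.typing import NDArray

def make_costs(n: int) -> NDArray[np.float64]:
    # np.zeros defaults to float64, so the runtime dtype matches the annotation
    return np.zeros(n)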
2 changes: 1 addition & 1 deletion .github/workflows/publish_to_pypi.yml
@@ -1,4 +1,4 @@
- name: Publish Python distribution to TestPyPI
+ name: Publish Python distribution to PyPI

on: push

6 changes: 3 additions & 3 deletions deepxube/environments/cube3.py
@@ -190,7 +190,7 @@ def next_state(self, states: List[Cube3State], actions: List[Cube3Action]) -> Tu
states_np = np.stack([x.colors for x in states], axis=0)

states_next_np = np.zeros(states_np.shape, dtype=np.uint8)
- tcs_np: NDArray[np.float_] = np.zeros(len(states))
+ tcs_np: NDArray[np.float64] = np.zeros(len(states))
for action in set(actions):
action_idxs: NDArray[np.int_] = np.array([idx for idx in range(len(actions)) if actions[idx] == action])
states_np_act = states_np[action_idxs]
@@ -589,7 +589,7 @@ def pddl_action_to_action(self, pddl_action: str) -> int:
assert match is not None
return int(match.group(1))

- def visualize(self, states: Union[List[Cube3State], List[Cube3Goal]]) -> NDArray[np.float_]:
+ def visualize(self, states: Union[List[Cube3State], List[Cube3Goal]]) -> NDArray[np.float64]:
# initialize
fig = plt.figure(figsize=(.64, .64))
viz = InteractiveCube(3, self.get_start_states(1)[0].colors)
@@ -600,7 +600,7 @@ def visualize(self, states: Union[List[Cube3State], List[Cube3Goal]]) -> NDArray
width = int(width)
height = int(height)

- states_img: NDArray[np.float_] = np.zeros((len(states), width, height, 6))
+ states_img: NDArray[np.float64] = np.zeros((len(states), width, height, 6))
for state_idx, state in enumerate(states):
# create image
if isinstance(state, Cube3State):
4 changes: 2 additions & 2 deletions deepxube/environments/environment_abstract.py
@@ -236,7 +236,7 @@ def pddl_action_to_action(self, pddl_action: str) -> int:
pass

@abstractmethod
- def visualize(self, states: Union[List[S], List[G]]) -> NDArray[np.float_]:
+ def visualize(self, states: Union[List[S], List[G]]) -> NDArray[np.float64]:
""" Implement if visualizing states. If you are planning on visualizing states, you do not have to implement
this (raise NotImplementedError).
@@ -311,7 +311,7 @@ def sample_goal(self, states_start: List[S], states_goal: List[S]) -> List[G]:
models_g: List[Model] = []

models_s: List[Model] = self.state_to_model(states_goal)
- keep_probs: NDArray[np.float_] = np.random.rand(len(states_goal))
+ keep_probs: NDArray[np.float64] = np.random.rand(len(states_goal))
for model_s, keep_prob in zip(models_s, keep_probs):
rand_subset: Set[Atom] = misc_utils.random_subset(model_s, keep_prob)
models_g.append(frozenset(rand_subset))
8 changes: 4 additions & 4 deletions deepxube/environments/n_puzzle.py
@@ -143,7 +143,7 @@ def next_state(self, states: List[NPState], actions: List[NPAction]) -> Tuple[Li
z_idxs: NDArray[np.int_]
_, z_idxs = np.where(states_next_np == 0)

- tcs_np: NDArray[np.float_] = np.zeros(len(states))
+ tcs_np: NDArray[np.float64] = np.zeros(len(states))
for action in set(actions):
action_idxs: NDArray[np.int_] = np.array([idx for idx in range(len(actions)) if actions[idx] == action])
states_np_act = states_np[action_idxs]
@@ -167,7 +167,7 @@ def expand(self, states: List[NPState]) -> Tuple[List[List[NPState]], List[List[
states_exp: List[List[NPState]] = [[] for _ in range(len(states))]
actions_exp_l: List[List[NPAction]] = [[] for _ in range(len(states))]

- tc: NDArray[np.float_] = np.empty([num_states, self.num_actions])
+ tc: NDArray[np.float64] = np.empty([num_states, self.num_actions])

# numpy states
states_np: NDArray[int_t] = np.stack([state.tiles for state in states])
@@ -389,7 +389,7 @@ def pddl_action_to_action(self, pddl_action: str) -> int:

raise ValueError(f"Unknown action {pddl_action}")

- def visualize(self, states: Union[List[NPState], List[NPGoal]]) -> NDArray[np.float_]:
+ def visualize(self, states: Union[List[NPState], List[NPGoal]]) -> NDArray[np.float64]:
fig = plt.figure(figsize=(.64, .64))
ax = fig.add_axes((0, 0, 1., 1.))
# fig = plt.figure(figsize=(.64, .64))
@@ -399,7 +399,7 @@ def visualize(self, states: Union[List[NPState], List[NPGoal]]) -> NDArray[np.fl
width, height = fig.get_size_inches() * fig.get_dpi()
width = int(width)
height = int(height)
- states_img: NDArray[np.float_] = np.zeros((len(states), width, height, 3))
+ states_img: NDArray[np.float64] = np.zeros((len(states), width, height, 3))
for state_idx, state in enumerate(states):
ax.clear()

4 changes: 2 additions & 2 deletions deepxube/environments/sokoban.py
@@ -611,8 +611,8 @@ def get_render_array(self, state: Union[SokobanState, SokobanGoal]) -> NDArray[n

return state_rendered

- def visualize(self, states: Union[List[SokobanState], List[SokobanGoal]]) -> NDArray[np.float_]:
-     states_img: NDArray[np.float_] = np.zeros((len(states), self.img_dim, self.img_dim, 3))
+ def visualize(self, states: Union[List[SokobanState], List[SokobanGoal]]) -> NDArray[np.float64]:
+     states_img: NDArray[np.float64] = np.zeros((len(states), self.img_dim, self.img_dim, 3))

from PIL import Image
if self._surfaces is None:
8 changes: 4 additions & 4 deletions deepxube/nnet/nnet_utils.py
@@ -11,7 +11,7 @@
from torch.multiprocessing import Queue, get_context
from multiprocessing.process import BaseProcess

- HeurFN_T = Callable[[Union[List[State], List[NDArray[Any]]], Optional[List[Goal]]], NDArray[np.float_]]
+ HeurFN_T = Callable[[Union[List[State], List[NDArray[Any]]], Optional[List[Goal]]], NDArray[np.float64]]


# training
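
For reference, a trivial callable satisfying the HeurFN_T alias above (a sketch, not from the repo; Any stands in for the State and Goal types): it ignores the goals and predicts a zero cost-to-go for every state.

from typing import Any, List, Optional
import numpy as np
from numpy.typing import NDArray

def zero_heuristic(states: List[Any], goals: Optional[List[Any]]) -> NDArray[np.float64]:
    # one float64 cost-to-go per input state, as HeurFN_T requires
    return np.zeros(len(states), dtype=np.float64)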
@@ -67,8 +67,8 @@ def get_heuristic_fn(nnet: nn.Module, device: torch.device, env: Environment, cl
nnet.eval()

def heuristic_fn(states: Union[List[State], List[NDArray[Any]]],
-                  goals: Optional[List[Goal]]) -> NDArray[np.float_]:
-     cost_to_go_l: List[NDArray[np.float_]] = []
+                  goals: Optional[List[Goal]]) -> NDArray[np.float64]:
+     cost_to_go_l: List[NDArray[np.float64]] = []

num_states: int
is_nnet_format: bool
@@ -100,7 +100,7 @@ def heuristic_fn(states: Union[List[State], List[NDArray[Any]]],
# get nnet output
states_goals_nnet_batch_tensors = to_pytorch_input(states_goals_nnet_batch, device)

- cost_to_go_batch: NDArray[np.float_] = nnet(states_goals_nnet_batch_tensors).cpu().data.numpy()
+ cost_to_go_batch: NDArray[np.float64] = nnet(states_goals_nnet_batch_tensors).cpu().data.numpy()
if is_v:
cost_to_go_batch = cost_to_go_batch[:, 0]
cost_to_go_l.append(cost_to_go_batch)
6 changes: 3 additions & 3 deletions deepxube/search/astar.py
@@ -172,18 +172,18 @@ def add_to_open(instances: List[Instance], nodes: List[List[Node]]) -> None:


def add_heuristic_and_cost(nodes: List[Node], heuristic_fn: HeurFN_T,
-                            weights: List[float]) -> Tuple[NDArray[np.float_], NDArray[np.float_]]:
+                            weights: List[float]) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
if len(nodes) == 0:
return np.zeros(0), np.zeros(0)

# compute node cost
states: List[State] = [node.state for node in nodes]
goals: List[Goal] = [node.goal for node in nodes]
heuristics = heuristic_fn(states, goals)
- path_costs: NDArray[np.float_] = np.array([node.path_cost for node in nodes])
+ path_costs: NDArray[np.float64] = np.array([node.path_cost for node in nodes])
is_solved: NDArray[np.bool_] = np.array([node.is_solved for node in nodes])

- costs: NDArray[np.float_] = np.array(weights) * path_costs + heuristics * np.logical_not(is_solved)
+ costs: NDArray[np.float64] = np.array(weights) * path_costs + heuristics * np.logical_not(is_solved)

# add cost to node
for node, heuristic, cost in zip(nodes, heuristics, costs):
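
The costs line above is weighted A* with the heuristic masked out for solved nodes: cost(n) = w * g(n) + h(n) * [n not solved]. A toy sketch of the same arithmetic (all values made up for illustration):

import numpy as np

weights = np.array([0.8, 0.8])      # per-instance weight w
path_costs = np.array([3.0, 5.0])   # g(n)
heuristics = np.array([2.0, 7.0])   # h(n) from heuristic_fn
is_solved = np.array([False, True])
costs = weights * path_costs + heuristics * np.logical_not(is_solved)
print(costs)                        # [4.4 4. ]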
8 changes: 4 additions & 4 deletions deepxube/search/greedy_policy.py
@@ -49,7 +49,7 @@ def add_instances(self, states: List[State], goals: List[Goal], eps_l: Optional[
self.instances.append(instance)

def step(self, heuristic_fn: HeurFN_T, times: Optional[Times] = None,
-          rand_seen: bool = False) -> Tuple[List[State], List[Goal], NDArray[np.float_]]:
+          rand_seen: bool = False) -> Tuple[List[State], List[Goal], NDArray[np.float64]]:
if times is None:
times = Times()

@@ -80,7 +80,7 @@ def step(self, heuristic_fn: HeurFN_T, times: Optional[Times] = None,
# get next state
if not is_solved[idx]:
state_exp: List[State] = states_exp[idx]
- ctg_next_p_tc: NDArray[np.float_] = ctg_next_p_tcs[idx]
+ ctg_next_p_tc: NDArray[np.float64] = ctg_next_p_tcs[idx]

state_next: State = state_exp[int(np.argmin(ctg_next_p_tc))]
seen_state: bool = state_next in instance.seen_states
@@ -139,7 +139,7 @@ def greedy_runner(env: Environment, heur_fn_q: HeurFnQ, proc_id: int,
num_steps_all: NDArray[np.int_] = np.array([instance.step_num for instance in greedy.instances])

# Get state cost-to-go
- state_ctg_all: NDArray[np.float_] = heuristic_fn(states, goals)
+ state_ctg_all: NDArray[np.float64] = heuristic_fn(states, goals)

results_queue.put((proc_id, is_solved_all, num_steps_all, state_ctg_all, inst_gen_steps))

@@ -206,7 +206,7 @@ def greedy_test(states: List[State], goals: List[Goal], inst_gen_steps: List[int

is_solved: NDArray[np.bool_] = is_solved_all[step_idxs]
num_steps: NDArray[np.int_] = num_steps_all[step_idxs]
- ctgs: NDArray[np.float_] = ctgs_all[step_idxs]
+ ctgs: NDArray[np.float64] = ctgs_all[step_idxs]

# Get stats
per_solved = 100 * float(sum(is_solved)) / float(len(is_solved))
12 changes: 6 additions & 6 deletions deepxube/search/search_utils.py
@@ -17,7 +17,7 @@ def is_valid_soln(state: State, goal: Goal, soln: List[Action], env: Environment

def bellman(states: List[State], goals: List[Goal], heuristic_fn,
env: Environment,
-             times: Optional[Times] = None) -> Tuple[NDArray[np.float_], List[NDArray[np.float_]], List[List[State]],
+             times: Optional[Times] = None) -> Tuple[NDArray[np.float64], List[NDArray[np.float64]], List[List[State]],
List[bool]]:
if times is None:
times = Times()
@@ -35,7 +35,7 @@ def bellman(states: List[State], goals: List[Goal], heuristic_fn,
for goal, state_exp in zip(goals, states_exp):
goals_flat.extend([goal] * len(state_exp))

- ctg_next_flat: NDArray[np.float_] = heuristic_fn(states_exp_flat, goals_flat)
+ ctg_next_flat: NDArray[np.float64] = heuristic_fn(states_exp_flat, goals_flat)
times.record_time("heur", time.time() - start_time)

# is solved
@@ -45,11 +45,11 @@ def bellman(states: List[State], goals: List[Goal], heuristic_fn,

# backup
start_time = time.time()
- tcs_flat: NDArray[np.float_] = np.hstack(tcs_l)
- ctg_next_p_tc_flat: NDArray[np.float_] = tcs_flat + ctg_next_flat
- ctg_next_p_tc_l: List[NDArray[np.float_]] = np.split(ctg_next_p_tc_flat, split_idxs)
+ tcs_flat: NDArray[np.float64] = np.hstack(tcs_l)
+ ctg_next_p_tc_flat: NDArray[np.float64] = tcs_flat + ctg_next_flat
+ ctg_next_p_tc_l: List[NDArray[np.float64]] = np.split(ctg_next_p_tc_flat, split_idxs)

- ctg_backup: NDArray[np.float_] = np.array([np.min(x) for x in ctg_next_p_tc_l]) * np.logical_not(is_solved)
+ ctg_backup: NDArray[np.float64] = np.array([np.min(x) for x in ctg_next_p_tc_l]) * np.logical_not(is_solved)
times.record_time("backup", time.time() - start_time)

return ctg_backup, ctg_next_p_tc_l, states_exp, is_solved
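
The backup above is the Bellman update J(s) <- min over successors s' of (c(s, s') + J(s')), zeroed out for solved states. A toy sketch with made-up numbers that mirrors the flatten/split bookkeeping in bellman():

import numpy as np

tcs_flat = np.array([1.0, 1.0, 1.0, 1.0])       # transition costs, flattened
ctg_next_flat = np.array([3.0, 2.5, 0.0, 4.0])  # heuristic at each successor
split_idxs = [2]                                # two states, two successors each
per_state = np.split(tcs_flat + ctg_next_flat, split_idxs)
is_solved = np.array([False, True])
ctg_backup = np.array([np.min(x) for x in per_state]) * np.logical_not(is_solved)
print(ctg_backup)                               # [3.5 0. ]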
10 changes: 5 additions & 5 deletions deepxube/training/avi.py
@@ -53,7 +53,7 @@ def __init__(self, env: Environment, step_max: int, num_test_per_step: int):

def do_update(step_max: int, update_num: int, env: Environment, step_update_max: int, num_states: int,
eps_max: float, heur_fn_qs: List[HeurFnQ],
-               update_batch_size: int) -> Tuple[List[NDArray[Any]], NDArray[np.float_]]:
+               update_batch_size: int) -> Tuple[List[NDArray[Any]], NDArray[np.float64]]:
update_steps: int = int(min(update_num + 1, step_update_max))
# num_states: int = int(np.ceil(num_states / update_steps))

@@ -72,7 +72,7 @@ def do_update(step_max: int, update_num: int, env: Environment, step_update_max:
"greedy", update_batch_size=update_batch_size, eps_max=eps_max)

nnet_rep: List[NDArray[Any]]
- ctgs: NDArray[np.float_]
+ ctgs: NDArray[np.float64]
nnet_rep, ctgs, is_solved = updater.update()

# Print stats
@@ -107,8 +107,8 @@ def load_data(model_dir: str, nnet_file: str, env: Environment, num_test_per_ste
return nnet, status


- def make_batches(nnet_rep: List[NDArray[Any]], ctgs: NDArray[np.float_],
-                  batch_size: int) -> List[Tuple[List[NDArray[Any]], NDArray[np.float_]]]:
+ def make_batches(nnet_rep: List[NDArray[Any]], ctgs: NDArray[np.float64],
+                  batch_size: int) -> List[Tuple[List[NDArray[Any]], NDArray[np.float64]]]:
num_examples = ctgs.shape[0]
rand_idxs = np.random.choice(num_examples, num_examples, replace=False)
ctgs = ctgs.astype(np.float32)
@@ -130,7 +130,7 @@ def make_batches(nnet_rep: List[NDArray[Any]], ctgs: NDArray[np.float_],
return batches


- def train_nnet(nnet: nn.Module, nnet_rep: List[NDArray[Any]], ctgs: NDArray[np.float_], device: torch.device,
+ def train_nnet(nnet: nn.Module, nnet_rep: List[NDArray[Any]], ctgs: NDArray[np.float64], device: torch.device,
batch_size: int, num_itrs: int, train_itr: int, lr: float, lr_d: float, display_itrs: int) -> float:
# optimization
criterion = nn.MSELoss()
10 changes: 5 additions & 5 deletions deepxube/updaters/updater.py
@@ -125,20 +125,20 @@ def __init__(self, env: Environment, num_states: int, back_max: int, step_probs:
proc.start()
self.procs.append(proc)

- def update(self) -> Tuple[List[NDArray[Any]], NDArray[np.float_], NDArray[np.bool_]]:
+ def update(self) -> Tuple[List[NDArray[Any]], NDArray[np.float64], NDArray[np.bool_]]:
states_goals_update_nnet: List[NDArray[Any]]
- cost_to_go_update: NDArray[np.float_]
+ cost_to_go_update: NDArray[np.float64]
is_solved: NDArray[np.bool_]
states_goals_update_nnet, cost_to_go_update, is_solved = self._update()

output_update = np.expand_dims(cost_to_go_update, 1)

return states_goals_update_nnet, output_update, is_solved

- def _update(self) -> Tuple[List[NDArray[Any]], NDArray[np.float_], NDArray[np.bool_]]:
+ def _update(self) -> Tuple[List[NDArray[Any]], NDArray[np.float64], NDArray[np.bool_]]:
# process results
states_goals_update_nnet_l: List[List[NDArray[Any]]] = []
- cost_to_go_update_l: List[NDArray[np.float_]] = []
+ cost_to_go_update_l: List[NDArray[np.float64]] = []
is_solved_l: List[NDArray[np.bool_]] = []

none_count: int = 0
@@ -182,7 +182,7 @@ def _update(self) -> Tuple[List[NDArray[Any]], NDArray[np.float_], NDArray[np.bo
axis=0)
states_goals_update_nnet.append(states_goals_nnet_idx)

- cost_to_go_update: NDArray[np.float_] = np.concatenate(cost_to_go_update_l, axis=0)
+ cost_to_go_update: NDArray[np.float64] = np.concatenate(cost_to_go_update_l, axis=0)
is_solved: NDArray[np.bool_] = np.concatenate(is_solved_l, axis=0)

for proc in self.procs:
6 changes: 3 additions & 3 deletions deepxube/utils/misc_utils.py
@@ -56,8 +56,8 @@ def boltzmann(vals: List[float], temp: float) -> List[float]:
if len(vals) == 1:
return [1.0]
else:
- vals_np: NDArray[np.float_] = np.array(vals)
- exp_vals_np: NDArray[np.float_] = np.exp((1.0 / temp) * (vals_np - np.max(vals_np)))
- probs_np: NDArray[np.float_] = exp_vals_np / np.sum(exp_vals_np)
+ vals_np: NDArray[np.float64] = np.array(vals)
+ exp_vals_np: NDArray[np.float64] = np.exp((1.0 / temp) * (vals_np - np.max(vals_np)))
+ probs_np: NDArray[np.float64] = exp_vals_np / np.sum(exp_vals_np)

return list(probs_np)
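
boltzmann() above is a numerically stable softmax over vals at temperature temp: subtracting np.max(vals_np) before exponentiating avoids overflow without changing the result. A standalone usage sketch (the function is re-declared here so the snippet runs without deepxube installed):

from typing import List
import numpy as np
from numpy.typing import NDArray

def boltzmann(vals: List[float], temp: float) -> List[float]:
    if len(vals) == 1:
        return [1.0]
    vals_np: NDArray[np.float64] = np.array(vals)
    exp_vals_np: NDArray[np.float64] = np.exp((1.0 / temp) * (vals_np - np.max(vals_np)))
    return list(exp_vals_np / np.sum(exp_vals_np))

probs = boltzmann([1.0, 2.0, 3.0], temp=0.5)
print(sum(probs))  # ~1.0; lower temp concentrates mass on the largest value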
