Skip to content

Commit

Permalink
Build working validity checker and plotter
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdh7 committed Feb 13, 2025
1 parent 8ebab03 commit 4ea6cac
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 11 deletions.
109 changes: 98 additions & 11 deletions PyCO2SYS/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,12 @@ def __init__(self, pd_index=None, xr_dims=None, xr_shape=None, **kwargs):
2: "xkcd:azure", # calculated en route to a user-requested parameter
3: "xkcd:tangerine", # calculated after direct user request
}
self.valid_colours = {
-1: "xkcd:light red", # invalid
0: "xkcd:light grey", # unknown
1: "xkcd:sky blue", # valid
}
self.checked_valid = False

def __getitem__(self, key):
# When the user requests a dict key that hasn't been solved for yet, then
Expand Down Expand Up @@ -1064,6 +1070,7 @@ def _assemble(self, icase, data):
for k, v in values_default.items():
if k not in data and k in graph.nodes:
data[k] = v
nx.set_node_attributes(graph, {k: 1}, name="state")
self.nodes_original = list(k for k, v in data.items() if v is not None)
return graph, funcs, data

Expand Down Expand Up @@ -1570,7 +1577,7 @@ def propagate(self, uncertainty_in, uncertainty_from):
New entries are added in the `uncertainty` attribute, for example:
co2s = CO2System(dic=2100, alkalinity=2300)
co2s = pyco2.sys(dic=2100, alkalinity=2300)
co2s.propagate("pH", {"dic": 2, "alkalinity": 2})
co2s.uncertainty["pH"]["total"] # total uncertainty in pH
co2s.uncertainty["pH"]["dic"] # component of ^ due to DIC uncertainty
Expand Down Expand Up @@ -1738,6 +1745,7 @@ def plot_graph(
node_kwargs=None,
edge_kwargs=None,
label_kwargs=None,
mode="state",
):
"""Draw a graph showing the relationships between the different parameters.
Expand Down Expand Up @@ -1785,6 +1793,8 @@ def plot_graph(
#
if ax is None:
ax = plt.subplots(dpi=300, figsize=(8, 7))[1]
if mode == "valid" and not self.checked_valid:
self.check_valid()
graph_to_plot = self.get_graph_to_plot(
exclude_nodes=exclude_nodes,
show_tsp=show_tsp,
Expand All @@ -1802,14 +1812,35 @@ def plot_graph(
nx_args=nx_args,
nx_kwargs=nx_kwargs,
)
node_states = nx.get_node_attributes(graph_to_plot, "state", default=0)
edge_states = nx.get_edge_attributes(graph_to_plot, "state", default=0)
node_colour = [
self.state_colours[node_states[n]] for n in nx.nodes(graph_to_plot)
]
edge_colour = [
self.state_colours[edge_states[e]] for e in nx.edges(graph_to_plot)
]
if mode == "state":
node_states = nx.get_node_attributes(graph_to_plot, "state", default=0)
edge_states = nx.get_edge_attributes(graph_to_plot, "state", default=0)
node_colour = [
self.state_colours[node_states[n]] for n in nx.nodes(graph_to_plot)
]
edge_colour = [
self.state_colours[edge_states[e]] for e in nx.edges(graph_to_plot)
]
elif mode == "valid":
node_valid = nx.get_node_attributes(graph_to_plot, "valid", default=0)
edge_valid = nx.get_edge_attributes(graph_to_plot, "valid", default=0)
node_valid_p = nx.get_node_attributes(graph_to_plot, "valid_p", default=0)
node_colour = [
self.valid_colours[node_valid[n]] for n in nx.nodes(graph_to_plot)
]
edge_colour = [
self.valid_colours[edge_valid[e]] for e in nx.edges(graph_to_plot)
]
node_edgecolors = [
self.valid_colours[node_valid_p[n]] for n in nx.nodes(graph_to_plot)
]
node_linewidths = [[0, 2][node_valid_p[n]] for n in nx.nodes(graph_to_plot)]
else:
warnings.warn(
f'mode "{mode}" not recognised, options are "state", "valid".'
)
node_colour = "xkcd:grey"
edge_colour = "xkcd:grey"
node_labels = {k: k for k in graph_to_plot.nodes}
for k, v in set_node_labels.items():
if k in node_labels:
Expand All @@ -1820,6 +1851,9 @@ def plot_graph(
edge_kwargs = {}
if label_kwargs is None:
label_kwargs = {}
if mode == "valid":
node_kwargs["edgecolors"] = node_edgecolors
node_kwargs["linewidths"] = node_linewidths
nx.draw_networkx_nodes(
graph_to_plot,
ax=ax,
Expand All @@ -1844,10 +1878,63 @@ def plot_graph(
return ax

def keys_all(self):
"""Return a list of all possible results keys, including those that have
"""Return a tuple of all possible results keys, including those that have
not yet been solved for.
"""
return list(self.graph.nodes)
return tuple(self.graph.nodes)

def check_valid(self):
"""Check which parameters are valid."""
for n in nx.topological_sort(self.graph):
# First, assign validity for functions that do have valid ranges
# (shown by node fill colour on the graph plot)
if n in self.funcs and hasattr(self.funcs[n], "valid"):
for p, p_range in self.funcs[n].valid.items():
n_valid = []
# If all predecessor parameters fall within valid ranges, it's valid
if np.all(
(self.data[p] >= p_range[0]) & (self.data[p] <= p_range[1])
):
n_valid.append(1)
nx.set_edge_attributes(
self.graph,
{(p, n): 1},
name="valid",
)
# If any predecessor parameter is outside any range, it's invalid
else:
n_valid.append(-1)
nx.set_edge_attributes(
self.graph,
{(p, n): -1},
name="valid",
)
nx.set_node_attributes(
self.graph,
{n: min(n_valid)},
name="valid",
)
# Next, assign inherited validity
# (shown by node edge colour on the graph plot)
n_valid_p = []
for p in self.graph.predecessors(n):
p_attrs = self.graph.nodes[p]
for v in ["valid", "valid_p"]:
if v in p_attrs:
n_valid_p.append(p_attrs[v])
if p_attrs[v] == -1:
nx.set_edge_attributes(
self.graph,
{(p, n): -1},
name="valid",
)
if -1 in n_valid_p:
nx.set_node_attributes(
self.graph,
{n: -1},
name="valid_p",
)
self.checked_valid = True


def sys(data=None, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions PyCO2SYS/equilibria/p1atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def k_CO2_W74(temperature, salinity):
return np.exp(lnK0)


@valid(temperature=[0, 45], salinity=[5, 45])
def k_HSO4_free_D90a(temperature, salinity, ionic_strength):
"""Bisulfate dissociation constant in mol/kg-sw on the free scale following D90a.
Used when opt_k_HSO4 = 1.
Expand Down
19 changes: 19 additions & 0 deletions tests/_test_valid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# %%
import networkx as nx
import numpy as np

import PyCO2SYS as pyco2

co2s = pyco2.sys(salinity=42)
co2s.solve(["k_HSO4_free", "k_CO2", "k_H2CO3"], store_steps=2)

# Visualise
co2s.plot_graph(
prog_graphviz="dot",
show_unknown=False,
show_isolated=False,
mode="valid",
node_kwargs={"node_size": 700},
edge_kwargs={"node_size": 700},
label_kwargs={"font_size": 9},
)

0 comments on commit 4ea6cac

Please sign in to comment.