Skip to content

Commit

Permalink
clearer error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 29, 2025
1 parent a9504a1 commit f9b2ef9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
3 changes: 2 additions & 1 deletion examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@
import torchpme
from torchpme.tuning import tune_pme

dtype = torch.float64

# %%
#
# Create the properties CsCl unit cell
dtype = torch.float64
symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=dtype)
Expand Down
21 changes: 12 additions & 9 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _validate_parameters(
) -> None:
if positions.dtype != dtype:
raise TypeError(
f"type of `positions` ({positions.dtype}) must be same as class "
f"type of `positions` ({positions.dtype}) must be same as the class "
f"type ({dtype})"
)

Expand All @@ -24,10 +24,13 @@ def _validate_parameters(

if positions.device.type != device:
raise ValueError(
f"device of `positions` ({positions.device}) must be same as class "
f"device of `positions` ({positions.device}) must be same as the class "
f"device ({device})"
)

# We use `positions.device` because it includes the device type AND index, which the
# `device` parameter may lack

# check shape, dtype and device of positions
num_atoms = len(positions)
if list(positions.shape) != [len(positions), 3]:
Expand All @@ -45,12 +48,12 @@ def _validate_parameters(

if cell.dtype != positions.dtype:
raise TypeError(
f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})"
)

if cell.device != positions.device:
raise ValueError(
f"device of `cell` ({cell.device}) must be same as `positions` ({device})"
f"device of `cell` ({cell.device}) must be same as the class ({device})"
)

if smearing is not None and torch.equal(
Expand Down Expand Up @@ -79,12 +82,12 @@ def _validate_parameters(

if charges.dtype != positions.dtype:
raise TypeError(
f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})"
f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})"
)

if charges.device != positions.device:
raise ValueError(
f"device of `charges` ({charges.device}) must be same as `positions` "
f"device of `charges` ({charges.device}) must be same as the class "
f"({device})"
)

Expand All @@ -99,7 +102,7 @@ def _validate_parameters(
if neighbor_indices.device != positions.device:
raise ValueError(
f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.shape != neighbor_indices[:, 0].shape:
Expand All @@ -112,11 +115,11 @@ def _validate_parameters(
if neighbor_distances.device != positions.device:
raise ValueError(
f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
f"same as `positions` ({device})"
f"same as the class ({device})"
)

if neighbor_distances.dtype != positions.dtype:
raise TypeError(
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
f"as `positions` ({dtype})"
f"as the class ({dtype})"
)
21 changes: 9 additions & 12 deletions tests/calculators/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_compute_output_shapes():

def test_wrong_device_positions():
calculator = CalculatorTest()
match = r"device of `positions` \(meta\) must be same as class device \(cpu\)"
match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1.to(device="meta"),
Expand All @@ -59,7 +59,7 @@ def test_wrong_device_positions():
def test_wrong_dtype_positions():
calculator = CalculatorTest()
match = (
r"type of `positions` \(torch.float64\) must be same as class type "
r"type of `positions` \(torch.float64\) must be same as the class type "
r"\(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
Expand Down Expand Up @@ -108,8 +108,7 @@ def test_invalid_shape_cell():
def test_invalid_dtype_cell():
calculator = CalculatorTest()
match = (
r"type of `cell` \(torch.float64\) must be same as `positions` "
r"\(torch.float32\)"
r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
Expand All @@ -123,7 +122,7 @@ def test_invalid_dtype_cell():

def test_invalid_device_cell():
calculator = CalculatorTest()
match = r"device of `cell` \(meta\) must be same as `positions` \(cpu\)"
match = r"device of `cell` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand Down Expand Up @@ -189,7 +188,7 @@ def test_invalid_shape_charges():
def test_invalid_dtype_charges():
calculator = CalculatorTest()
match = (
r"type of `charges` \(torch.float64\) must be same as `positions` "
r"type of `charges` \(torch.float64\) must be same as the class "
r"\(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
Expand All @@ -204,7 +203,7 @@ def test_invalid_dtype_charges():

def test_invalid_device_charges():
calculator = CalculatorTest()
match = r"device of `charges` \(meta\) must be same as `positions` \(cpu\)"
match = r"device of `charges` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand Down Expand Up @@ -249,7 +248,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances():

def test_invalid_device_neighbor_indices():
calculator = CalculatorTest()
match = r"device of `neighbor_indices` \(meta\) must be same as `positions` \(cpu\)"
match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand All @@ -262,9 +261,7 @@ def test_invalid_device_neighbor_indices():

def test_invalid_device_neighbor_distances():
calculator = CalculatorTest()
match = (
r"device of `neighbor_distances` \(meta\) must be same as `positions` \(cpu\)"
)
match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)"
with pytest.raises(ValueError, match=match):
calculator.forward(
positions=POSITIONS_1,
Expand All @@ -279,7 +276,7 @@ def test_invalid_dtype_neighbor_distances():
calculator = CalculatorTest()
match = (
r"type of `neighbor_distances` \(torch.float64\) must be same "
r"as `positions` \(torch.float32\)"
r"as the class \(torch.float32\)"
)
with pytest.raises(TypeError, match=match):
calculator.forward(
Expand Down

0 comments on commit f9b2ef9

Please sign in to comment.