diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py index 53d6b964..6465a507 100644 --- a/examples/01-charges-example.py +++ b/examples/01-charges-example.py @@ -39,10 +39,11 @@ import torchpme from torchpme.tuning import tune_pme +dtype = torch.float64 + # %% # # Create the properties CsCl unit cell -dtype = torch.float64 symbols = ("Cs", "Cl") types = torch.tensor([55, 17]) charges = torch.tensor([[1.0], [-1.0]], dtype=dtype) diff --git a/src/torchpme/_utils.py b/src/torchpme/_utils.py index f282a626..8d42e45c 100644 --- a/src/torchpme/_utils.py +++ b/src/torchpme/_utils.py @@ -15,7 +15,7 @@ def _validate_parameters( ) -> None: if positions.dtype != dtype: raise TypeError( - f"type of `positions` ({positions.dtype}) must be same as class " + f"type of `positions` ({positions.dtype}) must be same as the class " f"type ({dtype})" ) @@ -24,10 +24,13 @@ def _validate_parameters( if positions.device.type != device: raise ValueError( - f"device of `positions` ({positions.device}) must be same as class " + f"device of `positions` ({positions.device}) must be same as the class " f"device ({device})" ) + # We use `positions.device` because it includes the device type AND index, which the + # `device` parameter may lack + # check shape, dtype and device of positions num_atoms = len(positions) if list(positions.shape) != [len(positions), 3]: @@ -45,12 +48,12 @@ def _validate_parameters( if cell.dtype != positions.dtype: raise TypeError( - f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})" + f"type of `cell` ({cell.dtype}) must be same as the class ({dtype})" ) if cell.device != positions.device: raise ValueError( - f"device of `cell` ({cell.device}) must be same as `positions` ({device})" + f"device of `cell` ({cell.device}) must be same as the class ({device})" ) if smearing is not None and torch.equal( @@ -79,12 +82,12 @@ def _validate_parameters( if charges.dtype != positions.dtype: raise TypeError( - f"type of `charges` ({charges.dtype}) must be same as `positions` ({dtype})" + f"type of `charges` ({charges.dtype}) must be same as the class ({dtype})" ) if charges.device != positions.device: raise ValueError( - f"device of `charges` ({charges.device}) must be same as `positions` " + f"device of `charges` ({charges.device}) must be same as the class " f"({device})" ) @@ -99,7 +102,7 @@ def _validate_parameters( if neighbor_indices.device != positions.device: raise ValueError( f"device of `neighbor_indices` ({neighbor_indices.device}) must be " - f"same as `positions` ({device})" + f"same as the class ({device})" ) if neighbor_distances.shape != neighbor_indices[:, 0].shape: @@ -112,11 +115,11 @@ def _validate_parameters( if neighbor_distances.device != positions.device: raise ValueError( f"device of `neighbor_distances` ({neighbor_distances.device}) must be " - f"same as `positions` ({device})" + f"same as the class ({device})" ) if neighbor_distances.dtype != positions.dtype: raise TypeError( f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same " - f"as `positions` ({dtype})" + f"as the class ({dtype})" ) diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index 567b0f15..dff72455 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -45,7 +45,7 @@ def test_compute_output_shapes(): def test_wrong_device_positions(): calculator = CalculatorTest() - match = r"device of `positions` \(meta\) must be same as class device \(cpu\)" + match = r"device of `positions` \(meta\) must be same as the class device \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1.to(device="meta"), @@ -59,7 +59,7 @@ def test_wrong_device_positions(): def test_wrong_dtype_positions(): calculator = CalculatorTest() match = ( - r"type of `positions` \(torch.float64\) must be same as class type " + r"type of `positions` \(torch.float64\) must be same as the class type " r"\(torch.float32\)" ) with pytest.raises(TypeError, match=match): @@ -108,8 +108,7 @@ def test_invalid_shape_cell(): def test_invalid_dtype_cell(): calculator = CalculatorTest() match = ( - r"type of `cell` \(torch.float64\) must be same as `positions` " - r"\(torch.float32\)" + r"type of `cell` \(torch.float64\) must be same as the class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): calculator.forward( @@ -123,7 +122,7 @@ def test_invalid_dtype_cell(): def test_invalid_device_cell(): calculator = CalculatorTest() - match = r"device of `cell` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `cell` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -189,7 +188,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): calculator = CalculatorTest() match = ( - r"type of `charges` \(torch.float64\) must be same as `positions` " + r"type of `charges` \(torch.float64\) must be same as the class " r"\(torch.float32\)" ) with pytest.raises(TypeError, match=match): @@ -204,7 +203,7 @@ def test_invalid_dtype_charges(): def test_invalid_device_charges(): calculator = CalculatorTest() - match = r"device of `charges` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `charges` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -249,7 +248,7 @@ def test_invalid_shape_neighbor_indices_neighbor_distances(): def test_invalid_device_neighbor_indices(): calculator = CalculatorTest() - match = r"device of `neighbor_indices` \(meta\) must be same as `positions` \(cpu\)" + match = r"device of `neighbor_indices` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -262,9 +261,7 @@ def test_invalid_device_neighbor_indices(): def test_invalid_device_neighbor_distances(): calculator = CalculatorTest() - match = ( - r"device of `neighbor_distances` \(meta\) must be same as `positions` \(cpu\)" - ) + match = r"device of `neighbor_distances` \(meta\) must be same as the class \(cpu\)" with pytest.raises(ValueError, match=match): calculator.forward( positions=POSITIONS_1, @@ -279,7 +276,7 @@ def test_invalid_dtype_neighbor_distances(): calculator = CalculatorTest() match = ( r"type of `neighbor_distances` \(torch.float64\) must be same " - r"as `positions` \(torch.float32\)" + r"as the class \(torch.float32\)" ) with pytest.raises(TypeError, match=match): calculator.forward(