diff --git a/tests/test_potentials.py b/tests/test_potentials.py index 2898d65a..f1458b72 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -587,7 +587,7 @@ def test_potential_device(potential_class, device): exponent = 1.0 dtype = torch.float64 - if potential_class == InversePowerLawPotential: + if potential_class is InversePowerLawPotential: potential = potential_class( exponent=exponent, smearing=smearing, dtype=dtype, device=device )