Skip to content

Commit

Permalink
Fixing StandardScalar tests
Browse files Browse the repository at this point in the history
  • Loading branch information
naoise-h committed Feb 6, 2024
1 parent 29cd87c commit d369f63
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tests/models/test_StandardScaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sklearn.preprocessing as sk_pp

from diffprivlib.models.standard_scaler import StandardScaler
from diffprivlib.utils import PrivacyLeakWarning, DiffprivlibCompatibilityWarning, BudgetError
from diffprivlib.utils import PrivacyLeakWarning, DiffprivlibCompatibilityWarning, BudgetError, check_random_state


class TestStandardScaler(TestCase):
Expand Down Expand Up @@ -65,12 +65,13 @@ def test_inf_epsilon(self):
self.assertTrue(np.all(dp_ss.n_samples_seen_ == sk_ss.n_samples_seen_))

def test_different_results(self):
X = np.random.rand(10, 5)
rng = check_random_state(1)
X = rng.random((10, 5))

ss1 = StandardScaler(bounds=(0, 1))
ss1 = StandardScaler(bounds=(0, 1), random_state=rng)
ss1.fit(X)

ss2 = StandardScaler(bounds=(0, 1))
ss2 = StandardScaler(bounds=(0, 1), random_state=rng)
ss2.fit(X)

self.assertFalse(np.allclose(ss1.mean_, ss2.mean_), "Arrays %s and %s should be different" %
Expand All @@ -88,8 +89,8 @@ def test_functionality(self):
self.assertIsNotNone(ss.fit_transform(X))

def test_similar_results(self):
rng = np.random.RandomState(0)
X = rng.rand(100000, 5)
rng = check_random_state(0)
X = rng.random((100000, 5))

dp_ss = StandardScaler(bounds=(0, 1), epsilon=float("inf"), random_state=rng)
dp_ss.fit(X)
Expand All @@ -104,8 +105,8 @@ def test_similar_results(self):
self.assertTrue(np.all(dp_ss.n_samples_seen_ == sk_ss.n_samples_seen_))

def test_random_state(self):
rng = np.random.RandomState(0)
X = rng.rand(100000, 5)
rng = check_random_state(0)
X = rng.random((100000, 5))

ss0 = StandardScaler(bounds=(0, 1), epsilon=1, random_state=0)
ss1 = StandardScaler(bounds=(0, 1), epsilon=1, random_state=1)
Expand Down

0 comments on commit d369f63

Please sign in to comment.