From 8cbffcdb76bd15c94434c0c2faffd4212a379658 Mon Sep 17 00:00:00 2001
From: Tuan Pham
Date: Tue, 7 May 2024 23:18:15 +1000
Subject: [PATCH] Add Separate Wilson and Wald Interval Test

---
 .../rules/ConstraintRulesTest.scala           | 92 +++++++++++--------
 .../rules/interval/IntervalStrategyTest.scala | 41 +++++++----
 2 files changed, 80 insertions(+), 53 deletions(-)

diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala
index 701a5d98..328691fd 100644
--- a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala
+++ b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala
@@ -22,10 +22,13 @@ import com.amazon.deequ.checks.{Check, CheckLevel}
 import com.amazon.deequ.constraints.ConstrainableDataTypes
 import com.amazon.deequ.metrics.{Distribution, DistributionValue}
 import com.amazon.deequ.profiles._
+import com.amazon.deequ.suggestions.rules.interval.{WaldIntervalStrategy, WilsonScoreIntervalStrategy}
 import com.amazon.deequ.utils.FixtureSupport
 import com.amazon.deequ.{SparkContextSpec, VerificationSuite}
 import org.scalamock.scalatest.MockFactory
+import org.scalatest.Inspectors.forAll
 import org.scalatest.WordSpec
+import org.scalatest.prop.Tables.Table
 
 class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContextSpec with MockFactory{
 
@@ -132,6 +135,7 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext
       val complete = StandardColumnProfile("col1", 1.0, 100, String, false, Map.empty, None)
       val tenPercent = StandardColumnProfile("col1", 0.1, 100, String, false, Map.empty, None)
       val incomplete = StandardColumnProfile("col1", .25, 100, String, false, Map.empty, None)
+      val waldIntervalStrategy = WaldIntervalStrategy()
 
       assert(!RetainCompletenessRule().shouldBeApplied(complete, 1000))
       assert(!RetainCompletenessRule(0.05, 0.9).shouldBeApplied(complete, 1000))
@@ -139,74 +143,90 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext
       assert(RetainCompletenessRule(0.0).shouldBeApplied(tenPercent, 1000))
       assert(RetainCompletenessRule(0.0).shouldBeApplied(incomplete, 1000))
       assert(RetainCompletenessRule().shouldBeApplied(incomplete, 1000))
+      assert(!RetainCompletenessRule(intervalStrategy = waldIntervalStrategy).shouldBeApplied(complete, 1000))
+      assert(!RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(complete, 1000))
+      assert(RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(tenPercent, 1000))
     }
 
     "return evaluable constraint candidates" in withSparkSession { session =>
+      val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true))
+      forAll(table) { case (strategy, result) =>
+        val dfWithColumnCandidate = getDfFull(session)
 
-      val dfWithColumnCandidate = getDfFull(session)
+        val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
 
-      val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
+        val check = Check(CheckLevel.Warning, "some")
+          .addConstraint(RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100).constraint)
 
-      val check = Check(CheckLevel.Warning, "some")
-        .addConstraint(RetainCompletenessRule().candidate(fakeColumnProfile, 100).constraint)
+        val verificationResult = VerificationSuite()
+          .onData(dfWithColumnCandidate)
+          .addCheck(check)
+          .run()
 
-      val verificationResult = VerificationSuite()
-        .onData(dfWithColumnCandidate)
-        .addCheck(check)
-        .run()
+        val metricResult = verificationResult.metrics.head._2
 
-      val metricResult = verificationResult.metrics.head._2
+        assert(metricResult.value.isSuccess == result)
+      }
 
-      assert(metricResult.value.isSuccess)
     }
 
     "return working code to add constraint to check" in withSparkSession { session =>
+      val table = Table(
+        ("strategy", "colCompleteness", "targetCompleteness", "result"),
+        (WaldIntervalStrategy(), 0.5, 0.4, true),
+        (WilsonScoreIntervalStrategy(), 0.4, 0.3, true)
+      )
+      forAll(table) { case (strategy, colCompleteness, targetCompleteness, result) =>
 
-      val dfWithColumnCandidate = getDfFull(session)
+        val dfWithColumnCandidate = getDfFull(session)
 
-      val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
+        val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", colCompleteness)
 
-      val codeForConstraint = RetainCompletenessRule().candidate(fakeColumnProfile, 100)
-        .codeForConstraint
+        val codeForConstraint = RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100)
+          .codeForConstraint
 
-      val expectedCodeForConstraint = """.hasCompleteness("att1", _ >= 0.4,
-        | Some("It should be above 0.4!"))""".stripMargin.replaceAll("\n", "")
+        val expectedCodeForConstraint = s""".hasCompleteness("att1", _ >= $targetCompleteness,
+          | Some("It should be above $targetCompleteness!"))""".stripMargin.replaceAll("\n", "")
 
-      assert(expectedCodeForConstraint == codeForConstraint)
+        assert(expectedCodeForConstraint == codeForConstraint)
 
-      val check = Check(CheckLevel.Warning, "some")
-        .hasCompleteness("att1", _ >= 0.4, Some("It should be above 0.4!"))
+        val check = Check(CheckLevel.Warning, "some")
+          .hasCompleteness("att1", _ >= targetCompleteness, Some(s"It should be above $targetCompleteness"))
 
-      val verificationResult = VerificationSuite()
-        .onData(dfWithColumnCandidate)
-        .addCheck(check)
-        .run()
+        val verificationResult = VerificationSuite()
+          .onData(dfWithColumnCandidate)
+          .addCheck(check)
+          .run()
 
-      val metricResult = verificationResult.metrics.head._2
+        val metricResult = verificationResult.metrics.head._2
+
+        assert(metricResult.value.isSuccess == result)
+      }
 
-      assert(metricResult.value.isSuccess)
     }
 
    "return evaluable constraint candidates with custom min/max completeness" in withSparkSession { session =>
+      val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true))
+      forAll(table) { case (strategy, result) =>
+        val dfWithColumnCandidate = getDfFull(session)
 
-      val dfWithColumnCandidate = getDfFull(session)
-
-      val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
+        val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
 
-      val check = Check(CheckLevel.Warning, "some")
-        .addConstraint(RetainCompletenessRule(0.4, 0.6).candidate(fakeColumnProfile, 100).constraint)
+        val check = Check(CheckLevel.Warning, "some")
+          .addConstraint(RetainCompletenessRule(0.4, 0.6, strategy).candidate(fakeColumnProfile, 100).constraint)
 
-      val verificationResult = VerificationSuite()
-        .onData(dfWithColumnCandidate)
-        .addCheck(check)
-        .run()
+        val verificationResult = VerificationSuite()
+          .onData(dfWithColumnCandidate)
+          .addCheck(check)
+          .run()
 
-      val metricResult = verificationResult.metrics.head._2
+        val metricResult = verificationResult.metrics.head._2
 
-      assert(metricResult.value.isSuccess)
+        assert(metricResult.value.isSuccess == result)
+      }
     }
   }
 }
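Note on the file above: the new assertions assume RetainCompletenessRule now accepts a confidence interval strategy, either positionally after the min/max completeness bounds or via the intervalStrategy keyword, apparently defaulting to the Wilson score strategy. As a usage sketch only, the wiring below shows how a caller might pick the Wald strategy when generating suggestions; the ConstraintSuggestionRunner plumbing is an assumption based on Deequ's public suggestion API, not something this patch touches:

  import com.amazon.deequ.suggestions.ConstraintSuggestionRunner
  import com.amazon.deequ.suggestions.rules.RetainCompletenessRule
  import com.amazon.deequ.suggestions.rules.interval.WaldIntervalStrategy
  import org.apache.spark.sql.DataFrame

  object SuggestWithWaldInterval {
    // Suggest completeness constraints whose bounds come from a Wald interval
    // rather than the Wilson score default exercised in this patch.
    def run(df: DataFrame) =
      ConstraintSuggestionRunner()
        .onData(df)
        .addConstraintRule(RetainCompletenessRule(intervalStrategy = WaldIntervalStrategy()))
        .run()
  }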
diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
index 708fd285..2d102179 100644
--- a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
+++ b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
@@ -4,29 +4,36 @@ import com.amazon.deequ.SparkContextSpec
 import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval
 import com.amazon.deequ.utils.FixtureSupport
 import org.scalamock.scalatest.MockFactory
+import org.scalatest.Inspectors.forAll
+import org.scalatest.prop.Tables.Table
 import org.scalatest.wordspec.AnyWordSpec
 
 class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec
   with MockFactory {
-  "WaldIntervalStrategy" should {
+  "ConfidenceIntervalStrategy" should {
     "be calculated correctly" in {
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(1.0, 1.0))
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.84, 0.96))
-      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(1.0, 1.0))
-    }
-  }
+      val waldStrategy = WaldIntervalStrategy()
+      val wilsonStrategy = WilsonScoreIntervalStrategy()
+      val table = Table(
+        ("strategy", "pHat", "numRecords", "lowerBound", "upperBound"),
+        (waldStrategy, 1.0, 20L, 1.0, 1.0),
+        (waldStrategy, 0.5, 100L, 0.4, 0.6),
+        (waldStrategy, 0.4, 100L, 0.3, 0.5),
+        (waldStrategy, 0.6, 100L, 0.5, 0.7),
+        (waldStrategy, 0.9, 100L, 0.84, 0.96),
+        (waldStrategy, 1.0, 100L, 1.0, 1.0),
 
-  "WilsonIntervalStrategy" should {
-    "be calculated correctly" in {
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(0.83, 1.0))
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.82, 0.95))
-      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(0.96, 1.0))
+        (wilsonStrategy, 0.01, 20L, 0.00, 0.18),
+        (wilsonStrategy, 1.0, 20L, 0.83, 1.0),
+        (wilsonStrategy, 0.5, 100L, 0.4, 0.6),
+        (wilsonStrategy, 0.4, 100L, 0.3, 0.5),
+        (wilsonStrategy, 0.6, 100L, 0.5, 0.7),
+        (wilsonStrategy, 0.9, 100L, 0.82, 0.95),
+        (wilsonStrategy, 1.0, 100L, 0.96, 1.0)
+      )
+      forAll(table) { case (strategy, pHat, numRecords, lowerBound, upperBound) =>
+        assert(strategy.calculateTargetConfidenceInterval(pHat, numRecords) == ConfidenceInterval(lowerBound, upperBound))
+      }
     }
   }
 }
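The expected bounds in the table are consistent with the textbook Wald and Wilson score intervals at 95% confidence (z of roughly 1.96), with lower bounds rounded down and upper bounds rounded up to two decimal places; the z value and the rounding behaviour are inferred from the expected numbers, not stated by this patch, and Deequ's actual implementation may be structured differently. A self-contained sketch that reproduces the tabled values:

  import scala.math.BigDecimal.RoundingMode
  import scala.math.sqrt

  object IntervalSketch {
    private val Z = 1.96 // two-sided 95% confidence (assumed)

    // Round lower bounds down and upper bounds up, to two decimals (inferred).
    private def floor2(x: Double) = BigDecimal(x).setScale(2, RoundingMode.FLOOR).toDouble
    private def ceil2(x: Double) = BigDecimal(x).setScale(2, RoundingMode.CEILING).toDouble

    // Wald interval: pHat +- z * sqrt(pHat * (1 - pHat) / n)
    def wald(pHat: Double, n: Long): (Double, Double) = {
      val margin = Z * sqrt(pHat * (1 - pHat) / n)
      (floor2(pHat - margin), ceil2(pHat + margin))
    }

    // Wilson score interval:
    // (pHat + z^2/(2n) +- z * sqrt(pHat * (1 - pHat) / n + z^2/(4n^2))) / (1 + z^2/n)
    def wilson(pHat: Double, n: Long): (Double, Double) = {
      val z2 = Z * Z
      val center = pHat + z2 / (2 * n)
      val margin = Z * sqrt(pHat * (1 - pHat) / n + z2 / (4.0 * n * n))
      val denom = 1 + z2 / n
      (floor2((center - margin) / denom), ceil2((center + margin) / denom))
    }

    def main(args: Array[String]): Unit = {
      assert(wald(0.9, 100L) == (0.84, 0.96))   // matches the Wald row above
      assert(wilson(0.9, 100L) == (0.82, 0.95)) // matches the Wilson row above
      assert(wilson(0.01, 20L) == (0.0, 0.18))  // stays inside [0, 1]
    }
  }

The (wilsonStrategy, 0.01, 20L, 0.00, 0.18) row illustrates why a Wilson score strategy is worth having: with pHat = 0.01 and n = 20 the Wald interval comes out to roughly (-0.03, 0.05), escaping the valid [0, 1] range for a completeness ratio, while the Wilson score interval stays within it.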