From 20064c370e51e546e77e9cf6d3aff6afb819e917 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Wed, 10 Apr 2024 18:17:37 -0400 Subject: [PATCH 1/4] Fix flaky KLL test --- src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala | 5 +++++ src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala | 5 +++-- .../scala/com/amazon/deequ/analyzers/ColumnCountTest.scala | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala new file mode 100644 index 000000000..07a70d1e1 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala @@ -0,0 +1,5 @@ +package com.amazon.deequ.analyzers + +class ColumnCount { + +} diff --git a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala index 20017fa71..728ce866c 100644 --- a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala +++ b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala @@ -22,7 +22,8 @@ import com.amazon.deequ.analyzers.{Distance, QuantileNonSample} import com.amazon.deequ.metrics.BucketValue import com.amazon.deequ.utils.FixtureSupport import org.scalatest.WordSpec -import com.amazon.deequ.metrics.{BucketValue} +import com.amazon.deequ.metrics.BucketValue +import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper class KLLDistanceTest extends WordSpec with SparkContextSpec with FixtureSupport{ @@ -88,7 +89,7 @@ class KLLDistanceTest extends WordSpec with SparkContextSpec val sample2 = scala.collection.mutable.Map( "a" -> 22L, "b" -> 20L, "c" -> 25L, "d" -> 12L, "e" -> 13L, "f" -> 15L) val distance = Distance.categoricalDistance(sample1, sample2, method = LInfinityMethod(alpha = Some(0.003))) - assert(distance == 0.2726338046550349) + assert(distance === 0.2726338046550349 +- 1E-14) } "Categorial distance should compute correct linf_robust with different alpha value .1" in { diff --git a/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala new file mode 100644 index 000000000..cee769b9a --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala @@ -0,0 +1,4 @@ +import junit.framework.TestCase; +public class ColumnCountTest extends TestCase { + +} \ No newline at end of file From 5feaaaf377987201d7213d3b5a8b926b31fff6d0 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Wed, 10 Apr 2024 18:18:15 -0400 Subject: [PATCH 2/4] Move CustomSql state to CustomSql analyzer --- .../scala/com/amazon/deequ/analyzers/CustomSql.scala | 11 +++++++++++ src/main/scala/com/amazon/deequ/analyzers/Size.scala | 11 ----------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index e07e2d11f..edd4f8e97 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -26,6 +26,17 @@ import scala.util.Failure import scala.util.Success import scala.util.Try +case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] { + lazy val state = stateOrError.left.get + lazy val error = stateOrError.right.get + + override def sum(other: CustomSqlState): CustomSqlState = { + CustomSqlState(Left(state + other.state)) + } + + override def metricValue(): Double = state +} + case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { /** * Compute the state (sufficient statistics) from the data diff --git a/src/main/scala/com/amazon/deequ/analyzers/Size.scala b/src/main/scala/com/amazon/deequ/analyzers/Size.scala index a5080084a..c56083abe 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Size.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Size.scala @@ -20,17 +20,6 @@ import com.amazon.deequ.metrics.Entity import org.apache.spark.sql.{Column, Row} import Analyzers._ -case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] { - lazy val state = stateOrError.left.get - lazy val error = stateOrError.right.get - - override def sum(other: CustomSqlState): CustomSqlState = { - CustomSqlState(Left(state + other.state)) - } - - override def metricValue(): Double = state -} - case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] { override def sum(other: NumMatches): NumMatches = { From 5d5d8d49d636d03e568fd78dde9721c9c8ae6525 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Wed, 10 Apr 2024 18:18:29 -0400 Subject: [PATCH 3/4] Implement new Analyzer to count columns --- .../amazon/deequ/analyzers/ColumnCount.scala | 63 ++++++++++++++++++- .../scala/com/amazon/deequ/checks/Check.scala | 7 +++ .../amazon/deequ/constraints/Constraint.scala | 12 ++++ .../amazon/deequ/VerificationSuiteTest.scala | 1 + .../deequ/analyzers/ColumnCountTest.scala | 49 +++++++++++++-- 5 files changed, 127 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala index 07a70d1e1..53ad82300 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala @@ -1,5 +1,66 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + package com.amazon.deequ.analyzers -class ColumnCount { +import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame + +case class ColumnCount(where: Option[String] = None) extends Analyzer[NumMatches, DoubleMetric] { + + val name = "ColumnCount" + val instance = "*" + val entity = Entity.Dataset + + + /** + * Compute the state (sufficient statistics) from the data + * + * @param data data frame + * @return + */ + override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = { + if (filterCondition.isDefined) { + throw new IllegalArgumentException("ColumnCount does not accept a filter condition") + } else { + val numColumns = data.columns.size + Some(NumMatches(numColumns)) + } + } + + /** + * Compute the metric from the state (sufficient statistics) + * + * @param state wrapper holding a state of type S (required due to typing issues...) + * @return + */ + override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { + if (state.isDefined) { + Analyzers.metricFromValue(state.get.metricValue(), name, instance, entity) + } else { + Analyzers.metricFromEmpty(this, name, instance, entity) + } + } + /** + * Compute the metric from a failure - reports the exception thrown while trying to count columns + */ + override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = { + Analyzers.metricFromFailure(failure, name, instance, entity) + } } diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index 2537922be..1e1048921 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -127,6 +127,13 @@ case class Check( addFilterableConstraint { filter => Constraint.sizeConstraint(assertion, filter, hint) } } + def hasColumnCount(assertion: Long => Boolean, hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + addFilterableConstraint { + filter => Constraint.columnCountConstraint(assertion, hint) + } + } + /** * Creates a constraint that asserts on a column completion. * diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index c0e6e9b9d..e289b3859 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -128,6 +128,18 @@ object Constraint { new NamedConstraint(constraint, s"SizeConstraint($size)") } + def columnCountConstraint(assertion: Long => Boolean, hint: Option[String] = None): Constraint = { + val colCount = ColumnCount() + fromAnalyzer(colCount, assertion, hint) + } + + + def fromAnalyzer(colCount: ColumnCount, assertion: Long => Boolean, hint: Option[String]): Constraint = { + val constraint = AnalysisBasedConstraint[NumMatches, Double, Long](colCount, assertion, Some(_.toLong), hint) + + new NamedConstraint(constraint, name = s"ColumnCountConstraint($colCount)") + } + /** * Runs Histogram analysis on the given column and executes the assertion * diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index df13ea901..146579e8e 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -61,6 +61,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val checkToSucceed = Check(CheckLevel.Error, "group-1") .isComplete("att1") + .hasColumnCount(_ == 3) .hasCompleteness("att1", _ == 1.0) val checkToErrorOut = Check(CheckLevel.Error, "group-2-E") diff --git a/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala index cee769b9a..00df2758c 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala @@ -1,4 +1,45 @@ -import junit.framework.TestCase; -public class ColumnCountTest extends TestCase { - -} \ No newline at end of file +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + +package com.amazon.deequ.analyzers + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Failure +import scala.util.Success + +class ColumnCountTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + "ColumnCount" should { + "return column count for a dataset" in withSparkSession { session => + val data = getDfWithStringColumns(session) + val colCount = ColumnCount() + + val state = colCount.computeStateFrom(data) + state.isDefined shouldBe true + state.get.metricValue() shouldBe 5.0 + + val metric = colCount.computeMetricFrom(state) + metric.fullColumn shouldBe None + metric.value shouldBe Success(5.0) + } + } +} From 5c91225ec62c1d53071c6b9aae5c560101945687 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Mon, 15 Apr 2024 13:57:26 -0400 Subject: [PATCH 4/4] Improve documentation, remove unused parameter, replace if/else with map --- .../amazon/deequ/analyzers/ColumnCount.scala | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala index 53ad82300..9eff89b6d 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala @@ -21,18 +21,17 @@ import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.Entity import org.apache.spark.sql.DataFrame -case class ColumnCount(where: Option[String] = None) extends Analyzer[NumMatches, DoubleMetric] { +case class ColumnCount() extends Analyzer[NumMatches, DoubleMetric] { val name = "ColumnCount" val instance = "*" val entity = Entity.Dataset - /** * Compute the state (sufficient statistics) from the data * - * @param data data frame - * @return + * @param data the input dataframe + * @return the number of columns in the input */ override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = { if (filterCondition.isDefined) { @@ -46,15 +45,13 @@ case class ColumnCount(where: Option[String] = None) extends Analyzer[NumMatches /** * Compute the metric from the state (sufficient statistics) * - * @param state wrapper holding a state of type S (required due to typing issues...) - * @return + * @param state the computed state from [[computeStateFrom]] + * @return a double metric indicating the number of columns for this analyzer */ override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { - if (state.isDefined) { - Analyzers.metricFromValue(state.get.metricValue(), name, instance, entity) - } else { - Analyzers.metricFromEmpty(this, name, instance, entity) - } + state + .map(v => Analyzers.metricFromValue(v.metricValue(), name, instance, entity)) + .getOrElse(Analyzers.metricFromEmpty(this, name, instance, entity)) } /**