From 320da78a0c2b6784f269f60737b85b27ecf8a5e9 Mon Sep 17 00:00:00 2001 From: Gregor Ihmor Date: Sat, 16 Nov 2024 11:56:16 +0100 Subject: [PATCH] Add tests for required feature --- .../scalapb/compiler/ParseFromGenerator.scala | 22 ++++---- e2e/src/test/scala/RequiredFieldsSpec.scala | 51 ++++++++++++++++++- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala index 9a8b78984..21ab611bf 100644 --- a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala +++ b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala @@ -110,8 +110,11 @@ private[compiler] class ParseFromGenerator( private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isSingular - val requiredFieldMap: Map[FieldDescriptor, Int] = - message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex.toMap + private val requiredFields: Seq[(FieldDescriptor, Int)] = + message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex + + private val requiredFieldMap: Map[FieldDescriptor, Int] = + requiredFields.toMap val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message) @@ -231,16 +234,15 @@ private[compiler] class ParseFromGenerator( p.add(s"""if (${r}) {""") .indent .add("val __missingFields = Seq.newBuilder[_root_.scala.Predef.String]") - .print(requiredFieldMap.toSeq.sortBy(_._2)) { - case (p, (fieldDescriptor, fieldNumber)) => - val bitmask = s"0x${"%x".format(1L << fieldNumber)}L" - val fieldVariable = s"__requiredFields${fieldNumber / 64}" - p.add( - s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}"""" - ) + .print(requiredFields) { case (p, (fieldDescriptor, fieldNumber)) => + val bitmask = f"${1L << fieldNumber}%#018xL" + val fieldVariable = s"__requiredFields${fieldNumber / 64}" + p.add( + s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}"""" + ) } .add( - s"""val __message = s"Message missing required fields: $${__missingFields.result.mkString(", ")}"""", + s"""val __message = s"Message missing required fields: $${__missingFields.result().mkString(", ")}"""", s"""throw new _root_.com.google.protobuf.InvalidProtocolBufferException(__message)""" ) .outdent diff --git a/e2e/src/test/scala/RequiredFieldsSpec.scala b/e2e/src/test/scala/RequiredFieldsSpec.scala index 13dbf6b8a..575076b66 100644 --- a/e2e/src/test/scala/RequiredFieldsSpec.scala +++ b/e2e/src/test/scala/RequiredFieldsSpec.scala @@ -1,11 +1,60 @@ import com.google.protobuf.InvalidProtocolBufferException import com.thesamet.proto.e2e.reqs.RequiredFields +import protobuf_unittest.unittest.TestEmptyMessage +import scalapb.UnknownFieldSet import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.must.Matchers class RequiredFieldsSpec extends AnyFlatSpec with Matchers { + + private val descriptor = RequiredFields.javaDescriptor + + private def partialMessage(fields: Map[Int, Int]): Array[Byte] = { + val fieldSet = fields.foldLeft(UnknownFieldSet.empty){ case (fieldSet, (field, value)) => + fieldSet + .withField(field, UnknownFieldSet.Field(varint = Seq(value))) + } + + TestEmptyMessage(fieldSet).toByteArray + } + + private val allFieldsSet: Map[Int, Int] = (100 to 164).map(i => (i, i)).toMap + "RequiredMessage" should "throw InvalidProtocolBufferException for empty byte array" in { - intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) + + exception.getMessage() must startWith("Message missing required fields") + } + + it should "throw no exception when all fields are set correctly" in { + val parsed = RequiredFields.parseFrom(partialMessage(allFieldsSet)) + parsed must be(a[RequiredFields]) + parsed.f0 must be(100) + parsed.f64 must be(164) + } + + it should "throw an exeption if a field is missing and name the missing field" in { + val fields = allFieldsSet.removed(123) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + + exception.getMessage() must be("Message missing required fields: f23") + } + + it should "throw an exeption if a multiple fields are missing and name those missing fields" in { + val fields = allFieldsSet.removed(123).removed(164).removed(130) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + + exception.getMessage() must be("Message missing required fields: f23, f30, f64") + } + + it should "sort the missing fields by field number" in { + val fields = Map.empty[Int, Int] + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + val missingFields =exception.getMessage().stripPrefix("Message missing required fields: ").split(", ") + + missingFields.sortBy[Int](field => descriptor.findFieldByName(field).getNumber()) must be(missingFields) + + missingFields.toSeq mustBe Seq.tabulate(65)(i => s"f$i") } }