Skip to content

Commit

Permalink
Add a compile-time option to toggle serialization of enum values and …
Browse files Browse the repository at this point in the history
…sealed trait's case objects as JSON strings or JSON objects
  • Loading branch information
plokhotnyuk committed Feb 24, 2025
1 parent 80adab8 commit 571125a
Show file tree
Hide file tree
Showing 15 changed files with 721 additions and 314 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package zio.json

import zio.json.JsonCodecConfiguration.SumTypeHandling
import zio.json.JsonCodecConfiguration.SumTypeHandling.WrapperWithClassNameField

/**
* When disabled for decoding, keys with empty collections will be omitted from the JSON. When disabled for encoding,
* missing keys will default to empty collections.
*/
case class ExplicitEmptyCollections(encoding: Boolean = true, decoding: Boolean = true)

/**
* Implicit codec derivation configuration.
*
* @param sumTypeHandling
* see [[jsonDiscriminator]]
* @param fieldNameMapping
* see [[jsonMemberNames]]
* @param allowExtraFields
* see [[jsonNoExtraFields]]
* @param sumTypeMapping
* see [[jsonHintNames]]
* @param explicitNulls
* turns on explicit serialization of optional fields with None values
* @param explicitEmptyCollections
* turns on explicit serialization of fields with empty collections
* @param enumValuesAsStrings
* turns on serialization of enum values and sealed trait's case objects as strings
*/
final case class JsonCodecConfiguration(
sumTypeHandling: SumTypeHandling = WrapperWithClassNameField,
fieldNameMapping: JsonMemberFormat = IdentityFormat,
allowExtraFields: Boolean = true,
sumTypeMapping: JsonMemberFormat = IdentityFormat,
explicitNulls: Boolean = false,
explicitEmptyCollections: ExplicitEmptyCollections = ExplicitEmptyCollections(),
enumValuesAsStrings: Boolean = false
) {
def this(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean,
explicitEmptyCollections: ExplicitEmptyCollections
) = this(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
explicitEmptyCollections,
false
)

def this(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean
) = this(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
ExplicitEmptyCollections(),
false
)

def copy(
sumTypeHandling: SumTypeHandling = WrapperWithClassNameField.asInstanceOf[SumTypeHandling],
fieldNameMapping: JsonMemberFormat = IdentityFormat.asInstanceOf[JsonMemberFormat],
allowExtraFields: Boolean = true,
sumTypeMapping: JsonMemberFormat = IdentityFormat.asInstanceOf[JsonMemberFormat],
explicitNulls: Boolean = false,
explicitEmptyCollections: ExplicitEmptyCollections = ExplicitEmptyCollections(),
enumValuesAsStrings: Boolean = false
) = new JsonCodecConfiguration(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
explicitEmptyCollections,
enumValuesAsStrings
)

def copy(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean,
explicitEmptyCollections: ExplicitEmptyCollections
) = new JsonCodecConfiguration(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
explicitEmptyCollections,
this.enumValuesAsStrings
)

def copy(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean
) = new JsonCodecConfiguration(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
this.explicitEmptyCollections,
this.enumValuesAsStrings
)
}

object JsonCodecConfiguration {
def apply(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean,
explicitEmptyCollections: ExplicitEmptyCollections
) = new JsonCodecConfiguration(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
explicitEmptyCollections,
false
)

def apply(
sumTypeHandling: SumTypeHandling,
fieldNameMapping: JsonMemberFormat,
allowExtraFields: Boolean,
sumTypeMapping: JsonMemberFormat,
explicitNulls: Boolean
) = new JsonCodecConfiguration(
sumTypeHandling,
fieldNameMapping,
allowExtraFields,
sumTypeMapping,
explicitNulls,
ExplicitEmptyCollections(),
false
)

implicit val default: JsonCodecConfiguration = JsonCodecConfiguration()

sealed trait SumTypeHandling {
def discriminatorField: Option[String]
}

object SumTypeHandling {

/**
* Use an object with a single key that is the class name.
*/
case object WrapperWithClassNameField extends SumTypeHandling {
override def discriminatorField: Option[String] = None
}

/**
* For sealed classes, will determine the name of the field for disambiguating classes.
*
* The default is to not use a typehint field and instead have an object with a single key that is the class name.
* See [[WrapperWithClassNameField]].
*
* Note that using a discriminator is less performant, uses more memory, and may be prone to DOS attacks that are
* impossible with the default encoding. In addition, there is slightly less type safety when using custom product
* encoders (which must write an unenforced object type). Only use this option if you must model an externally
* defined schema.
*/
final case class DiscriminatorField(name: String) extends SumTypeHandling {
override def discriminatorField: Option[String] = Some(name)
}
}
}
94 changes: 65 additions & 29 deletions zio-json/shared/src/main/scala-2.x/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,25 @@ final class jsonNoExtraFields extends Annotation
*/
final class jsonExclude extends Annotation

private class CaseObjectDecoder[Typeclass[_], A](val ctx: CaseClass[Typeclass, A], no_extra: Boolean)
extends CollectionJsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
if (no_extra) {
Lexer.char(trace, in, '{')
Lexer.char(trace, in, '}')
} else Lexer.skipValue(trace, in)
ctx.rawConstruct(Nil)
}

override def unsafeDecodeMissing(trace: List[JsonError]): A = ctx.rawConstruct(Nil)

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case _: Json.Obj | Json.Null => ctx.rawConstruct(Nil)
case _ => Lexer.error("expected object", trace)
}
}

object DeriveJsonDecoder {
type Typeclass[A] = JsonDecoder[A]

Expand All @@ -212,25 +231,7 @@ object DeriveJsonDecoder {
}.isDefined || !config.allowExtraFields

if (ctx.parameters.isEmpty)
new CollectionJsonDecoder[A] {
override def unsafeDecodeMissing(trace: List[JsonError]): A = ctx.rawConstruct(Nil)

def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
if (no_extra) {
Lexer.char(trace, in, '{')
Lexer.char(trace, in, '}')
} else {
Lexer.skipValue(trace, in)
}
ctx.rawConstruct(Nil)
}

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case _: Json.Obj | Json.Null => ctx.rawConstruct(Nil)
case _ => Lexer.error("expected object", trace)
}
}
new CaseObjectDecoder(ctx, no_extra)
else
new CollectionJsonDecoder[A] {
private[this] val (names, aliases): (Array[String], Array[(String, Int)]) = {
Expand Down Expand Up @@ -403,10 +404,31 @@ object DeriveJsonDecoder {
lazy val tcs = ctx.subtypes.map(_.typeclass).toArray.asInstanceOf[Array[JsonDecoder[Any]]]
lazy val namesMap = names.zipWithIndex.toMap

def discrim =
val isEnumeration = config.enumValuesAsStrings &&
ctx.subtypes.forall(_.typeclass.isInstanceOf[CaseObjectDecoder[JsonDecoder, _]])

val discrim =
ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }.orElse(config.sumTypeHandling.discriminatorField)

if (discrim.isEmpty) {
if (isEnumeration && discrim.isEmpty) {
new JsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
val idx = Lexer.enumeration(trace, in, matrix)
if (idx != -1) tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
else Lexer.error("invalid enumeration value", trace)
}

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case s: Json.Str =>
namesMap.get(s.value) match {
case Some(idx) => tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case _ => Lexer.error("invalid enumeration value", trace)
}
case _ => Lexer.error("expected string", trace)
}
}
} else if (discrim.isEmpty) {
// We're not allowing extra fields in this encoding
new JsonDecoder[A] {
private[this] val spans = names.map(JsonError.ObjectAccess)
Expand Down Expand Up @@ -481,17 +503,19 @@ object DeriveJsonDecoder {
}

object DeriveJsonEncoder {
private lazy val caseObjectEncoder = new JsonEncoder[Any] {
override def isEmpty(a: Any): Boolean = true

def unsafeEncode(a: Any, indent: Option[Int], out: Write): Unit = out.write("{}")

override final def toJsonAST(a: Any): Either[String, Json] = new Right(Json.Obj.empty)
}

type Typeclass[A] = JsonEncoder[A]

def join[A](ctx: CaseClass[JsonEncoder, A])(implicit config: JsonCodecConfiguration): JsonEncoder[A] =
if (ctx.parameters.isEmpty)
new JsonEncoder[A] {
override def isEmpty(a: A): Boolean = true

def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = out.write("{}")

override final def toJsonAST(a: A): Either[String, Json] = new Right(Json.Obj.empty)
}
caseObjectEncoder.narrow[A]
else
new JsonEncoder[A] {
private[this] val (transformNames, nameTransform): (Boolean, String => String) =
Expand Down Expand Up @@ -584,6 +608,8 @@ object DeriveJsonEncoder {
}

def split[A](ctx: SealedTrait[JsonEncoder, A])(implicit config: JsonCodecConfiguration): JsonEncoder[A] = {
val isEnumeration = config.enumValuesAsStrings &&
ctx.subtypes.forall(_.typeclass == caseObjectEncoder)
val jsonHintFormat: JsonMemberFormat =
ctx.annotations.collectFirst { case jsonHintNames(format) => format }.getOrElse(config.sumTypeMapping)
val names: Array[String] = ctx.subtypes.map { p =>
Expand All @@ -592,7 +618,17 @@ object DeriveJsonEncoder {
val discrim =
ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }.orElse(config.sumTypeHandling.discriminatorField)

if (discrim.isEmpty) {
if (isEnumeration && discrim.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = ctx.split(a) { sub =>
JsonEncoder.string.unsafeEncode(names(sub.index), indent, out)
}

override final def toJsonAST(a: A): Either[String, Json] = ctx.split(a) { sub =>
new Right(new Json.Str(names(sub.index)))
}
}
} else if (discrim.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = ctx.split(a) { sub =>
out.write('{')
Expand Down
Loading

0 comments on commit 571125a

Please sign in to comment.