Skip to content

Commit

Permalink
Encode discriminator field first + clean up error messages (#1290)
Browse files — browse the repository at this point in the history
  • Loading branch information
plokhotnyuk authored Feb 6, 2025
1 parent 22c9a05 commit 049508a
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 193 deletions.
79 changes: 32 additions & 47 deletions zio-json/shared/src/main/scala-2.x/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import magnolia1._
import zio.Chunk
import zio.json.JsonDecoder.JsonError
import zio.json.ast.Json
import zio.json.internal.{ FieldEncoder, Lexer, RetractReader, StringMatrix, Write }
import zio.json.internal.{ FieldEncoder, Lexer, RecordingReader, RetractReader, StringMatrix, Write }

import scala.annotation._
import scala.language.experimental.macros
Expand Down Expand Up @@ -88,10 +88,6 @@ object ziojson_03 {
final case class jsonMemberNames(format: JsonMemberFormat) extends Annotation
private[json] object jsonMemberNames {

/**
* ~~Stolen~~ Borrowed from jsoniter-scala by Andriy Plokhotnyuk (he even granted permission for this, imagine that!)
*/

import java.lang.Character._

def enforceCamelOrPascalCase(s: String, toPascal: Boolean): String =
Expand Down Expand Up @@ -231,9 +227,8 @@ object DeriveJsonDecoder {

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A =
json match {
case Json.Obj(_) => ctx.rawConstruct(Nil)
case Json.Null => ctx.rawConstruct(Nil)
case _ => Lexer.error("Not an object", trace)
case _: Json.Obj | Json.Null => ctx.rawConstruct(Nil)
case _ => Lexer.error("expected object", trace)
}
}
else
Expand Down Expand Up @@ -356,7 +351,6 @@ object DeriveJsonDecoder {
}
idx += 1
}

ctx.rawConstruct(new ArraySeq(ps))
}

Expand Down Expand Up @@ -387,7 +381,7 @@ object DeriveJsonDecoder {
idx += 1
}
ctx.rawConstruct(new ArraySeq(ps))
case _ => Lexer.error("Not an object", trace)
case _ => Lexer.error("expected object", trace)
}
}
}
Expand All @@ -412,13 +406,13 @@ object DeriveJsonDecoder {
def discrim =
ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }.orElse(config.sumTypeHandling.discriminatorField)

if (discrim.isEmpty)
if (discrim.isEmpty) {
// We're not allowing extra fields in this encoding
new JsonDecoder[A] {
private[this] val spans = names.map(JsonError.ObjectAccess)

def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
Lexer.char(trace, in, '{')
// we're not allowing extra fields in this encoding
if (Lexer.firstField(trace, in)) {
val idx = Lexer.field(trace, in, matrix)
if (idx != -1) {
Expand All @@ -435,22 +429,21 @@ object DeriveJsonDecoder {
val keyValue = chunk.head
namesMap.get(keyValue._1) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, keyValue._2).asInstanceOf[A]
case _ => Lexer.error("Invalid disambiguator", trace)
case _ => Lexer.error("invalid disambiguator", trace)
}
case Json.Obj(_) => Lexer.error("Not an object with a single field", trace)
case _ => Lexer.error("Not an object", trace)
case _ => Lexer.error("expected single field object", trace)
}
}
else
} else {
new JsonDecoder[A] {
private[this] val hintfield = discrim.get
private[this] val hintmatrix = new StringMatrix(Array(hintfield))
private[this] val spans = names.map(JsonError.Message)

def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
val in_ = internal.RecordingReader(in)
val in_ = RecordingReader(in)
Lexer.char(trace, in_, '{')
if (Lexer.firstField(trace, in_))
if (Lexer.firstField(trace, in_)) {
do {
if (Lexer.field(trace, in_, hintmatrix) != -1) {
val idx = Lexer.enumeration(trace, in_, matrix)
Expand All @@ -460,6 +453,7 @@ object DeriveJsonDecoder {
} else Lexer.error("invalid disambiguator", trace)
} else Lexer.skipValue(trace, in_)
} while (Lexer.nextField(trace, in_))
}
Lexer.error(s"missing hint '$hintfield'", trace)
}

Expand All @@ -469,15 +463,15 @@ object DeriveJsonDecoder {
fields.find { case (key, _) => key == hintfield } match {
case Some((_, Json.Str(name))) =>
namesMap.get(name) match {
case Some(idx) => tcs(idx).unsafeFromJsonAST(trace, json).asInstanceOf[A]
case _ => Lexer.error("Invalid disambiguator", trace)
case Some(idx) => tcs(idx).unsafeFromJsonAST(spans(idx) :: trace, json).asInstanceOf[A]
case _ => Lexer.error("invalid disambiguator", trace)
}
case Some(_) => Lexer.error(s"Non-string hint '$hintfield'", trace)
case _ => Lexer.error(s"Missing hint '$hintfield'", trace)
case _ => Lexer.error(s"missing hint '$hintfield'", trace)
}
case _ => Lexer.error("Not an object", trace)
case _ => Lexer.error("expected object", trace)
}
}
}
}

def gen[A]: JsonDecoder[A] = macro Magnolia.gen[A]
Expand All @@ -489,13 +483,11 @@ object DeriveJsonEncoder {
def join[A](ctx: CaseClass[JsonEncoder, A])(implicit config: JsonCodecConfiguration): JsonEncoder[A] =
if (ctx.parameters.isEmpty)
new JsonEncoder[A] {

override def isEmpty(a: A): Boolean = true

def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = out.write("{}")

override final def toJsonAST(a: A): Either[String, Json] =
Right(Json.Obj(Chunk.empty))
override final def toJsonAST(a: A): Either[String, Json] = new Right(Json.Obj.empty)
}
else
new JsonEncoder[A] {
Expand Down Expand Up @@ -557,16 +549,14 @@ object DeriveJsonEncoder {
case _ => true
}
}) {
// if we have at least one field already, we need a comma
if (prevFields) {
out.write(',')
JsonEncoder.pad(indent_, out)
}
} else prevFields = true
JsonEncoder.string.unsafeEncode(field.name, indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
encoder.unsafeEncode(p, indent_, out)
prevFields = true // record that we have at least one field so far
}
idx += 1
}
Expand All @@ -578,7 +568,7 @@ object DeriveJsonEncoder {
fields
.foldLeft[Either[String, Chunk[(String, Json)]]](Right(Chunk.empty)) { case (c, field) =>
val param = field.p
val paramValue = field.p.dereference(a).asInstanceOf[param.PType]
val paramValue = param.dereference(a).asInstanceOf[param.PType]
field.encodeOrDefault(paramValue)(
() =>
c.flatMap { chunk =>
Expand All @@ -596,8 +586,7 @@ object DeriveJsonEncoder {
val names: Array[String] = ctx.subtypes.map { p =>
p.annotations.collectFirst { case jsonHint(name) => name }.getOrElse(jsonHintFormat(p.typeName.short))
}.toArray

def discrim =
val discrim =
ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }.orElse(config.sumTypeHandling.discriminatorField)

if (discrim.isEmpty) {
Expand All @@ -614,16 +603,11 @@ object DeriveJsonEncoder {
out.write('}')
}

override def toJsonAST(a: A): Either[String, Json] =
ctx.split(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).map { inner =>
Json.Obj(
Chunk(
names(sub.index) -> inner
)
)
}
override def toJsonAST(a: A): Either[String, Json] = ctx.split(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).map { inner =>
Json.Obj(Chunk(names(sub.index) -> inner))
}
}
}
} else {
new JsonEncoder[A] {
Expand All @@ -642,13 +626,14 @@ object DeriveJsonEncoder {
sub.typeclass.unsafeEncode(sub.cast(a), indent, intermediate)
}

override def toJsonAST(a: A): Either[String, Json] =
ctx.split(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).flatMap {
case Json.Obj(fields) => Right(Json.Obj(fields :+ hintfield -> Json.Str(names(sub.index))))
case _ => Left("Subtype is not encoded as an object")
}
override def toJsonAST(a: A): Either[String, Json] = ctx.split(a) { sub =>
sub.typeclass.toJsonAST(sub.cast(a)).flatMap {
case Json.Obj(fields) =>
new Right(Json.Obj((hintfield -> Json.Str(names(sub.index))) +: fields)) // hint field is always first
case _ =>
new Left("expected object")
}
}
}
}
}
Expand Down
Loading

0 comments on commit 049508a

Please sign in to comment.