Skip to content

Commit

Permalink
Add derivation of Schema for union types (closes ghostdogpr#1926)
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindberg committed May 7, 2024
1 parent ab7dc85 commit 08202b8
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
2 changes: 2 additions & 0 deletions core/src/main/scala-3/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ trait SchemaDerivation[R] extends CommonSchemaDerivation {

inline def genDebug[R, A]: Schema[R, A] = PrintDerived(derived[R, A])

inline def unionType[T]: Schema[R, T] = ${ TypeUnionDerivation.typeUnionSchema[R, T] }

final lazy val auto = new AutoSchemaDerivation[Any] {}

final class SemiAuto[A](impl: Schema[R, A]) extends Schema[R, A] {
Expand Down
76 changes: 76 additions & 0 deletions core/src/main/scala-3/caliban/schema/TypeUnionDerivation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package caliban.schema

import caliban.introspection.adt.__Type

import scala.quoted.*

object TypeUnionDerivation {
inline def derived[R, T]: Schema[R, T] = ${ typeUnionSchema[R, T] }

def typeUnionSchema[R: Type, T: Type](using quotes: Quotes): Expr[Schema[R, T]] = {
import quotes.reflect.*

class TypeAndSchema[A](val typeRef: String, val schema: Expr[Schema[R, A]], val tpe: Type[A])

def rec[A](using tpe: Type[A]): List[TypeAndSchema[?]] =
TypeRepr.of(using tpe).dealias match {
case OrType(l, r) =>
rec(using l.asType.asInstanceOf[Type[Any]]) ++ rec(using r.asType.asInstanceOf[Type[Any]])
case otherRepr =>
val otherString: String = otherRepr.show
val expr: TypeAndSchema[A] =
Expr.summon[Schema[R, A]] match {
case Some(foundSchema) =>
TypeAndSchema[A](otherString, foundSchema, otherRepr.asType.asInstanceOf[Type[A]])
case None =>
quotes.reflect.report.errorAndAbort(s"Couldn't resolve Schema[Any, $otherString]")
}

List(expr)
}

val typeAndSchemas: List[TypeAndSchema[?]] = rec[T]

val schemaByTypeNameList: Expr[List[(String, Schema[R, Any])]] = Expr.ofList(
typeAndSchemas.map { case (tas: TypeAndSchema[a]) =>
given Type[a] = tas.tpe
'{ (${ Expr(tas.typeRef) }, ${ tas.schema }.asInstanceOf[Schema[R, Any]]) }
}
)
val name = TypeRepr.of[T].show

if (name.contains("|")) {
report.error(
s"You must explicitly add type parameter to derive Schema for a union type in order to capture the name of the type alias"
)
}

'{
val schemaByName: Map[String, Schema[R, Any]] = ${ schemaByTypeNameList }.toMap
new Schema[R, T] {

def resolve(value: T): Step[R] = {
var ret: Step[R] = null
${
Expr.block(
typeAndSchemas.map { case (tas: TypeAndSchema[a]) =>
given Type[a] = tas.tpe

'{ if value.isInstanceOf[a] then ret = schemaByName(${ Expr(tas.typeRef) }).resolve(value) }
},
'{ require(ret != null, s"no schema for ${value}") }
)
}
ret
}

def toType(isInput: Boolean, isSubscription: Boolean): __Type =
Types.makeUnion(
Some(${ Expr(name) }),
None,
schemaByName.values.map(_.toType_(isInput, isSubscription)).toList
)
}
}
}
}
50 changes: 50 additions & 0 deletions core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,56 @@ object Scala3DerivesSpec extends ZIOSpecDefault {
data1 == """{"enum2String":"ENUM1"}""",
data2 == """{"enum2String":"ENUM2"}"""
)
},
test("union type") {
final case class Foo(value: String) derives Schema.SemiAuto
final case class Bar(foo: Int) derives Schema.SemiAuto
final case class Baz(bar: Int) derives Schema.SemiAuto
type Payload = Foo | Bar | Baz

given Schema[Any, Payload] = Schema.unionType[Payload]

final case class QueryInput(isFoo: Boolean) derives ArgBuilder, Schema.SemiAuto
final case class Query(testQuery: QueryInput => zio.UIO[Payload]) derives Schema.SemiAuto

val gql = graphQL(RootResolver(Query(i => ZIO.succeed(if (i.isFoo) Foo("foo") else Bar(1)))))

val expectedSchema =
"""
schema {
query: Query
}
union Payload = Foo | Bar | Baz
type Bar {
foo: Int!
}
type Baz {
bar: Int!
}
type Foo {
value: String!
}
type Query {
testQuery(isFoo: Boolean!): Payload!
}
""".stripMargin
val interpreter = gql.interpreterUnsafe

for {
res1 <- interpreter.execute("{ testQuery(isFoo: true){ ... on Foo { value } } }")
res2 <- interpreter.execute("{ testQuery(isFoo: false){ ... on Bar { foo } } }")
data1 = res1.data.toString
data2 = res2.data.toString
} yield assertTrue(
data1 == """{"testQuery":{"value":"foo"}}""",
data2 == """{"testQuery":{"foo":1}}""",
gql.render == expectedSchema
)
}
)
}
Expand Down

0 comments on commit 08202b8

Please sign in to comment.