Skip to content

Commit

Permalink
Support proto3 optional (#818)
Browse files Browse the repository at this point in the history
* Support proto3 optional

* Fix protobuf refined
  • Loading branch information
RustedBones authored Nov 7, 2023
1 parent 745b9cc commit ca4c0d6
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 133 deletions.
2 changes: 0 additions & 2 deletions docs/protobuf.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ implicit val efEnum = ProtobufField.enum[Color.Type, ColorProto]

Additional `ProtobufField[T]` instances for `Byte`, `Char`, `Short`, and `UnsafeEnum[T]` are available from `import magnolify.protobuf.unsafe._`. These conversions are unsafe due to potential overflow.

By default nullable type `Option[T]` is not supported when `MsgT` is compiled with Protobuf 3 syntax. This is because Protobuf 3 does not offer a way to check if a field was set, and instead returns `0`, `""`, `false`, etc. when it was not. You can enable Protobuf 3 support for `Option[T]` by adding `import magnolify.protobuf.unsafe.Proto3Option._`. However with this, Scala `None`s will become `0/""/false` in Protobuf and come back as `Some(0/""/false)`.

To use a different field case format in target records, add an optional `CaseMapper` argument to `ProtobufType`. The following example maps `firstName` & `lastName` to `first_name` & `last_name`.

```scala
Expand Down
67 changes: 16 additions & 51 deletions protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,31 @@
package magnolify.protobuf

import java.lang.reflect.Method
import java.{util => ju}

import com.google.protobuf.Descriptors.FileDescriptor.Syntax
import java.util as ju
import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor}
import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum}
import magnolia1._
import magnolify.shared._
import magnolia1.*
import magnolify.shared.*
import magnolify.shims.FactoryCompat

import scala.annotation.implicitNotFound
import scala.collection.concurrent
import scala.reflect.ClassTag
import scala.jdk.CollectionConverters._
import scala.collection.compat._
import scala.jdk.CollectionConverters.*
import scala.collection.compat.*

sealed trait ProtobufType[T, MsgT <: Message] extends Converter[T, MsgT, MsgT] {
def apply(r: MsgT): T = from(r)
def apply(t: T): MsgT = to(t)
}

sealed trait ProtobufOption {
def check(f: ProtobufField.Record[_], syntax: Syntax): Unit
}

object ProtobufOption {
implicit val proto2Option: ProtobufOption = new ProtobufOption {
override def check(f: ProtobufField.Record[_], syntax: Syntax): Unit =
if (f.hasOptional) {
require(
syntax == Syntax.PROTO2,
"Option[T] support is PROTO2 only, " +
"`import magnolify.protobuf.unsafe.Proto3Option._` to enable PROTO3 support"
)
}
}

private[protobuf] class Proto3Option extends ProtobufOption {
override def check(f: ProtobufField.Record[_], syntax: Syntax): Unit = ()
}
}

object ProtobufType {
implicit def apply[T: ProtobufField, MsgT <: Message: ClassTag](implicit
po: ProtobufOption
): ProtobufType[T, MsgT] = ProtobufType(CaseMapper.identity)
implicit def apply[T: ProtobufField, MsgT <: Message: ClassTag]: ProtobufType[T, MsgT] =
ProtobufType(CaseMapper.identity)

def apply[T, MsgT <: Message](cm: CaseMapper)(implicit
f: ProtobufField[T],
ct: ClassTag[MsgT],
po: ProtobufOption
ct: ClassTag[MsgT]
): ProtobufType[T, MsgT] = f match {
case r: ProtobufField.Record[_] =>
new ProtobufType[T, MsgT] {
Expand All @@ -74,9 +50,7 @@ object ProtobufType {
.getMethod("getDescriptor")
.invoke(null)
.asInstanceOf[Descriptor]
if (r.hasOptional) {
po.check(r, descriptor.getFile.getSyntax)
}

r.checkDefaults(descriptor)(cm)
}

Expand All @@ -101,7 +75,6 @@ sealed trait ProtobufField[T] extends Serializable {
type FromT
type ToT

val hasOptional: Boolean
val default: Option[T]

def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = ()
Expand Down Expand Up @@ -133,7 +106,7 @@ object ProtobufField {
new ProtobufField[T] {
override type FromT = tc.FromT
override type ToT = tc.ToT
override val hasOptional: Boolean = tc.hasOptional

override val default: Option[T] = tc.default.map(x => caseClass.construct(_ => x))
override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm))
override def to(v: T, b: Message.Builder)(cm: CaseMapper): ToT =
Expand All @@ -157,14 +130,11 @@ object ProtobufField {
}
)

override val hasOptional: Boolean = caseClass.parameters.exists(_.typeclass.hasOptional)

override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = {
val syntax = descriptor.getFile.getSyntax
val fields = getFields(descriptor)(cm)
caseClass.parameters.foreach { p =>
val field = fields(p.index)
val protoDefault = if (syntax == Syntax.PROTO2 && field.hasDefaultValue) {
val protoDefault = if (field.hasDefaultValue) {
Some(p.typeclass.fromAny(field.getDefaultValue)(cm))
} else {
p.typeclass.default
Expand All @@ -183,13 +153,12 @@ object ProtobufField {

override def from(v: Message)(cm: CaseMapper): T = {
val descriptor = v.getDescriptorForType
val syntax = descriptor.getFile.getSyntax
val fields = getFields(descriptor)(cm)

caseClass.construct { p =>
val field = fields(p.index)
// hasField behaves correctly on PROTO2 optional fields
val value = if (syntax == Syntax.PROTO2 && field.isOptional && !v.hasField(field)) {
// check hasPresence to make sure hasField is meaningful
val value = if (field.hasPresence && !v.hasField(field)) {
null
} else {
v.getField(field)
Expand Down Expand Up @@ -234,7 +203,6 @@ object ProtobufField {
class FromWord[T] {
def apply[U](f: T => U)(g: U => T)(implicit pf: ProtobufField[T]): ProtobufField[U] =
new Aux[U, pf.FromT, pf.ToT] {
override val hasOptional: Boolean = pf.hasOptional
override val default: Option[U] = pf.default.map(f)
override def from(v: FromT)(cm: CaseMapper): U = f(pf.from(v)(cm))
override def to(v: U, b: Message.Builder)(cm: CaseMapper): ToT = pf.to(g(v), null)(cm)
Expand All @@ -243,7 +211,6 @@ object ProtobufField {

private def aux[T, From, To](_default: T)(f: From => T)(g: T => To): ProtobufField[T] =
new Aux[T, From, To] {
override val hasOptional: Boolean = false
override val default: Option[T] = Some(_default)
override def from(v: FromT)(cm: CaseMapper): T = f(v)
override def to(v: T, b: Message.Builder)(cm: CaseMapper): ToT = g(v)
Expand All @@ -262,9 +229,9 @@ object ProtobufField {
implicit val pfString: ProtobufField[String] = id[String]("")
implicit val pfByteString: ProtobufField[ByteString] = id[ByteString](ByteString.EMPTY)
implicit val pfByteArray: ProtobufField[Array[Byte]] =
aux2[Array[Byte], ByteString](Array.emptyByteArray)(_.toByteArray)(ByteString.copyFrom)
aux2[Array[Byte], ByteString](Array.emptyByteArray)(b => b.toByteArray)(ByteString.copyFrom)

def `enum`[T, E <: Enum[E] with ProtocolMessageEnum](implicit
implicit def `enum`[T, E <: Enum[E] with ProtocolMessageEnum](implicit
et: EnumType[T],
ct: ClassTag[E]
): ProtobufField[T] = {
Expand All @@ -282,7 +249,6 @@ object ProtobufField {

implicit def pfOption[T](implicit f: ProtobufField[T]): ProtobufField[Option[T]] =
new Aux[Option[T], f.FromT, f.ToT] {
override val hasOptional: Boolean = true
override val default: Option[Option[T]] = f.default match {
case Some(v) => Some(Some(v))
case None => None
Expand All @@ -306,7 +272,6 @@ object ProtobufField {
fc: FactoryCompat[T, C[T]]
): ProtobufField[C[T]] =
new Aux[C[T], ju.List[f.FromT], ju.List[f.ToT]] {
override val hasOptional: Boolean = false
override val default: Option[C[T]] = Some(fc.newBuilder.result())
override def from(v: ju.List[f.FromT])(cm: CaseMapper): C[T] = {
val b = fc.newBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ package object unsafe {
implicit val pfChar: ProtobufField[Char] = ProtobufField.from[Int](_.toChar)(_.toInt)
implicit val pfShort: ProtobufField[Short] = ProtobufField.from[Int](_.toShort)(_.toInt)

object Proto3Option {
implicit val proto3Option: ProtobufOption = new ProtobufOption.Proto3Option
}

implicit def pfUnsafeEnum[T: EnumType]: ProtobufField[UnsafeEnum[T]] =
ProtobufField
.from[String](s => Option(s).filter(_.nonEmpty).map(UnsafeEnum.from[T]).orNull)(
Expand Down
25 changes: 16 additions & 9 deletions protobuf/src/test/protobuf/Proto3.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ message FloatsP3 {
double d = 2;
}

message SingularP3 {
message RequiredP3 {
bool b = 1;
string s = 2;
int32 i = 3;
}

message NullableP3 {
optional bool b = 1;
optional string s = 2;
optional int32 i = 3;
}

message RepeatedP3 {
repeated bool b = 1;
repeated string s = 2;
Expand All @@ -30,8 +36,9 @@ message NestedP3 {
bool b = 1;
string s = 2;
int32 i = 3;
SingularP3 r = 4;
repeated SingularP3 l = 5;
RequiredP3 r = 4;
optional RequiredP3 o = 5;
repeated RequiredP3 l = 6;
}

message CollectionP3 {
Expand Down Expand Up @@ -60,9 +67,9 @@ message EnumsP3 {
JavaEnums j = 1;
ScalaEnums s = 2; // Enumeration
ScalaEnums a = 3; // ADT
JavaEnums jo = 4;
ScalaEnums so = 5; // Enumeration
ScalaEnums ao = 6; // ADT
optional JavaEnums jo = 4;
optional ScalaEnums so = 5; // Enumeration
optional ScalaEnums ao = 6; // ADT
repeated JavaEnums jr = 7;
repeated ScalaEnums sr = 8; // Enumeration
repeated ScalaEnums ar = 9; // ADT
Expand All @@ -72,9 +79,9 @@ message UnsafeEnumsP3 {
string j = 1;
string s = 2;
string a = 3;
string jo = 4;
string so = 5;
string ao = 6;
optional string jo = 4;
optional string so = 5;
optional string ao = 6;
repeated string jr = 7;
repeated string sr = 8;
repeated string ar = 9;
Expand Down
62 changes: 8 additions & 54 deletions protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,43 +60,26 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite {
test[Floats, FloatsP2]
test[Floats, FloatsP3]
test[Required, RequiredP2]
test[Required, SingularP3]
test[Required, RequiredP3]
test[Nullable, NullableP2]
test[Nullable, NullableP3]

test[Repeated, RepeatedP2]
test[Repeated, RepeatedP3]
test[Nested, NestedP2]
test[NestedNoOption, NestedP3]
test[Nested, NestedP3]
test[UnsafeByte, IntegersP2]
test[UnsafeChar, IntegersP2]
test[UnsafeShort, IntegersP2]

test[Collections, CollectionP2]
test[MoreCollections, MoreCollectionP2]
test[Collections, CollectionP3]
test[MoreCollections, MoreCollectionP2]
test[MoreCollections, MoreCollectionP3]

// PROTO3 removes the notion of require vs optional fields.
// By default `Option[T] are not supported`.
test("Fail PROTO3 Option[T]") {
val msg = "requirement failed: Option[T] support is PROTO2 only, " +
"`import magnolify.protobuf.unsafe.Proto3Option._` to enable PROTO3 support"
interceptMessage[IllegalArgumentException](msg)(ProtobufType[Nullable, SingularP3])
}

// Adding `import magnolify.protobuf.unsafe.Proto3Option._` enables PROTO3 `Option[T]` support.
// The new singular field returns default value if unset.
// Hence `None` round trips back as `Some(false/0/"")`.
{
import magnolify.protobuf.unsafe.Proto3Option._
implicit val eq: Eq[Nullable] = Eq.by { x =>
Required(x.b.getOrElse(false), x.i.getOrElse(0), x.s.getOrElse(""))
}
test[Nullable, SingularP3]
}

test("AnyVal") {
test[ProtoHasValueClass, IntegersP2]
test[ProtoHasValueClass, IntegersP3]
}
}

Expand All @@ -121,23 +104,7 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite {

{
import Proto3Enums._
import magnolify.protobuf.unsafe.Proto3Option._
// Enums are encoded as integers and default to zero value
implicit val eq: Eq[Enums] = Eq.by(e =>
(
e.j,
e.s,
e.a,
e.jo.getOrElse(JavaEnums.Color.RED),
e.so.getOrElse(ScalaEnums.Color.Red),
e.ao.getOrElse(ADT.Red),
e.jr,
e.sr,
e.ar
)
)
test[Enums, EnumsP3]
// Unsafe enums are encoded as string and default "" is treated as None
test[UnsafeEnums, UnsafeEnumsP3]
}

Expand All @@ -160,16 +127,12 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite {
import Proto3Enums._
test[DefaultIntegers3, IntegersP3]
test[DefaultFloats3, FloatsP3]
test[DefaultRequired3, SingularP3]
test[DefaultRequired3, RequiredP3]
test[DefaultEnums3, EnumsP3]
}

{
import magnolify.protobuf.unsafe.Proto3Option._
implicit val eq: Eq[DefaultNullable3] = Eq.by { x =>
Required(x.b.getOrElse(false), x.i.getOrElse(0), x.s.getOrElse(""))
}
test[DefaultNullable3, SingularP3]
test[DefaultNullable3, NullableP3]
}

{
Expand All @@ -178,14 +141,13 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite {
testFail[F, DefaultMismatch2](ProtobufType[DefaultMismatch2, DefaultRequiredP2])(
"Default mismatch magnolify.protobuf.DefaultMismatch2#i: 321 != 123"
)
testFail[F, DefaultMismatch3](ProtobufType[DefaultMismatch3, SingularP3])(
testFail[F, DefaultMismatch3](ProtobufType[DefaultMismatch3, RequiredP3])(
"Default mismatch magnolify.protobuf.DefaultMismatch3#i: 321 != 0"
)
}
}

object Proto2Enums {
// FIXME: for some reasons these implicits fail to resolve without explicit types
implicit val efJavaEnum2: ProtobufField[JavaEnums.Color] =
ProtobufField.enum[JavaEnums.Color, EnumsP2.JavaEnums]
implicit val efScalaEnum2: ProtobufField[ScalaEnums.Color.Type] =
Expand All @@ -195,7 +157,6 @@ object Proto2Enums {
}

object Proto3Enums {
// FIXME: for some reasons these implicits fail to resolve without explicit types
implicit val efJavaEnum3: ProtobufField[JavaEnums.Color] =
ProtobufField.enum[JavaEnums.Color, EnumsP3.JavaEnums]
implicit val efScalaEnum3: ProtobufField[ScalaEnums.Color.Type] =
Expand All @@ -211,13 +172,6 @@ case class UnsafeChar(i: Char, l: Long)
case class UnsafeShort(i: Short, l: Long)
case class BytesA(b: ByteString)
case class BytesB(b: Array[Byte])
case class NestedNoOption(
b: Boolean,
i: Int,
s: String,
r: Required,
l: List[Required]
)

case class DefaultsRequired2(
i: Int = 123,
Expand Down
Loading

0 comments on commit ca4c0d6

Please sign in to comment.