Skip to content

Instantly share code, notes, and snippets.

@Baccata
Created October 3, 2024 13:22
Show Gist options
  • Save Baccata/6bf451e5eae5c8de3a941ae2027f0cd9 to your computer and use it in GitHub Desktop.
Save Baccata/6bf451e5eae5c8de3a941ae2027f0cd9 to your computer and use it in GitHub Desktop.
Generically producing ScalaCheck `Gen` instances from ScalaPB-generated code. Mostly useless.
package foo
import scalapb.GeneratedMessageCompanion
import scalapb.descriptors._
import org.scalacheck.Gen
import com.google.protobuf.ByteString
import scalapb.GeneratedMessage
object ProtoScalacheck {

  /** Default number of times a message type may be nested within itself when generating values
    * for recursive schemas. See [[messageGen]].
    */
  val DefaultMaxRecursionDepth: Int = 2

  /** A field descriptor paired with the value generated for it. */
  private type Field = (FieldDescriptor, PValue)

  /** Generically derives a random generator from the protobuf descriptor associated to a datatype generated by ScalaPB.
    * The data generated will be nonsensical but should suffice to assert a few properties.
    *
    * Every field that is not part of a oneof is populated. For each oneof, either no member or exactly one
    * randomly chosen member is populated, so every alternative (and the empty case) can be produced.
    * Recursive schemas are cut off: along any path a message type is nested within itself at most
    * `maxRecursionDepth` times. Beyond that, non-required fields of that type are left unset (singular)
    * or empty (repeated/map). Non-recursive schemas are unaffected by this bound.
    *
    * @param cmp               companion object of the ScalaPB-generated message type
    * @param maxRecursionDepth how many times a message type may be nested within itself
    * @tparam A                the ScalaPB-generated message type
    * @return a generator of `A` values, built by reading a generated `PMessage` through `cmp.messageReads`
    */
  def messageGen[A <: GeneratedMessage](
      cmp: GeneratedMessageCompanion[A],
      maxRecursionDepth: Int = DefaultMaxRecursionDepth
  ): Gen[A] =
    pMessageGen(cmp, maxRecursionDepth, Map.empty).map(cmp.messageReads.read)

  // Inspired from https://github.com/scalapb-json/scalapb-circe/blob/master/core/shared/src/main/scala/scalapb_circe/JsonFormat.scala
  /** Generates the field map of one message.
    *
    * @param ancestors occurrence count of each message type (by full name) on the path from the root
    *                  message down to, but excluding, `cmp`'s message
    */
  private def pMessageGen(
      cmp: GeneratedMessageCompanion[_],
      maxRecursionDepth: Int,
      ancestors: Map[String, Int]
  ): Gen[PMessage] =
    // Short circuiting on Protobuf structs, which represent arbitrary JSON objects: they are always generated empty.
    if (cmp.javaDescriptor.getFullName() == com.google.protobuf.Struct.getDescriptor().getFullName()) {
      Gen.const(PMessage(Map.empty))
    } else {
      val descriptor = cmp.scalaDescriptor
      val path = ancestors.updated(descriptor.fullName, ancestors.getOrElse(descriptor.fullName, 0) + 1)

      // Generating another value of type `d` would nest it within itself more than allowed.
      def exhausted(d: Descriptor): Boolean = path.getOrElse(d.fullName, 0) > maxRecursionDepth

      // Required fields are always kept, because omitting them would make `messageReads` fail.
      // Dropped fields read back as unset (singular) or empty (repeated/map).
      val generatable = descriptor.fields.filter(fd => fd.isRequired || !nestedMessageType(fd).exists(exhausted))
      val (oneofMembers, standalone) = generatable.partition(_.containingOneof.isDefined)

      def entryGen(fd: FieldDescriptor): Gen[Option[Field]] =
        fieldGen(cmp, fd, maxRecursionDepth, path).map(value => Option(fd -> value))

      // ScalaPB reads a oneof by taking the first present member, so populating every member would always
      // yield the first alternative. Instead pick none or exactly one member per oneof.
      val oneofGens: Vector[Gen[Option[Field]]] =
        oneofMembers.flatMap(_.containingOneof).distinct.map { oneof =>
          val members = oneofMembers.filter(_.containingOneof.contains(oneof))
          val alternatives = Gen.const(Option.empty[Field]) +: members.map(entryGen)
          Gen.oneOf(alternatives).flatMap(chosen => chosen)
        }

      Gen
        .sequence[Vector[Option[Field]], Option[Field]](standalone.map(entryGen) ++ oneofGens)
        .map(entries => PMessage(entries.flatten.toMap))
    }

  /** Generates the value of one field: a `PRepeated` of entry messages for map fields, a `PRepeated` of
    * single values for repeated fields, otherwise a single value.
    */
  private def fieldGen(
      cmp: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      maxRecursionDepth: Int,
      path: Map[String, Int]
  ): Gen[PValue] =
    if (fd.isMapField) {
      // Map fields are encoded as repeated synthetic entry messages (key = field 1, value = field 2).
      // Use the entry companion's own descriptor: its `messageReads` checks field ownership against it.
      val entryCompanion = cmp.messageCompanionForFieldNumber(fd.number)
      val entryDescriptor = entryCompanion.scalaDescriptor
      val keyDescriptor = entryDescriptor.findFieldByNumber(1).get
      val valueDescriptor = entryDescriptor.findFieldByNumber(2).get
      val entryGen = for {
        key <- singleValueGen(entryCompanion, keyDescriptor, maxRecursionDepth, path)
        value <- singleValueGen(entryCompanion, valueDescriptor, maxRecursionDepth, path)
      } yield PMessage(Map(keyDescriptor -> key, valueDescriptor -> value))
      Gen.listOf(entryGen).map(entries => PRepeated(entries.toVector))
    } else if (fd.isRepeated) {
      Gen.listOf(singleValueGen(cmp, fd, maxRecursionDepth, path)).map(values => PRepeated(values.toVector))
    } else {
      singleValueGen(cmp, fd, maxRecursionDepth, path)
    }

  /** Generates one non-repeated value for `fd`, recursing into `pMessageGen` for message-typed fields. */
  private def singleValueGen(
      cmp: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      maxRecursionDepth: Int,
      path: Map[String, Int]
  ): Gen[PValue] =
    fd.scalaType match {
      case ScalaType.Enum(ed)   => Gen.oneOf(ed.values).map(PEnum(_))
      case ScalaType.Long       => Gen.long.map(PLong(_))
      case ScalaType.Boolean    => Gen.oneOf(true, false).map(PBoolean(_))
      case ScalaType.Int        => Gen.chooseNum(1, 100).map(PInt(_))
      case ScalaType.Double     => Gen.chooseNum(1d, 100d).map(PDouble(_))
      case ScalaType.Float      => Gen.chooseNum(1f, 100f).map(PFloat(_))
      case ScalaType.String     => Gen.identifier.map(PString(_))
      case ScalaType.ByteString => Gen.identifier.map(s => PByteString(ByteString.copyFromUtf8(s)))
      case ScalaType.Message(_) =>
        pMessageGen(cmp.messageCompanionForFieldNumber(fd.number), maxRecursionDepth, path)
    }

  /** The message type that values of `fd` nest into, if any. This is the field's own message type, or,
    * for map fields, the message type of the map's values.
    */
  private def nestedMessageType(fd: FieldDescriptor): Option[Descriptor] =
    fd.scalaType match {
      case ScalaType.Message(entry) if fd.isMapField =>
        entry.findFieldByNumber(2).flatMap(valueField => messageDescriptor(valueField.scalaType))
      case other => messageDescriptor(other)
    }

  /** The descriptor carried by a message `ScalaType`, or `None` for scalar/enum types. */
  private def messageDescriptor(scalaType: ScalaType): Option[Descriptor] =
    scalaType match {
      case ScalaType.Message(d) => Some(d)
      case _                    => None
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment