Created
October 3, 2024 13:22
-
-
Save Baccata/6bf451e5eae5c8de3a941ae2027f0cd9 to your computer and use it in GitHub Desktop.
Generically producing ScalaCheck `Gen` instances from ScalaPB-generated code. Mostly useless.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package foo | |
import scalapb.GeneratedMessageCompanion | |
import scalapb.descriptors._ | |
import org.scalacheck.Gen | |
import com.google.protobuf.ByteString | |
import scalapb.GeneratedMessage | |
object ProtoScalacheck {

  /** Generically derives a random generator from the protobuf descriptor associated with a datatype generated by
    * ScalaPB. The data generated will be nonsensical but should suffice to assert a few properties.
    *
    * Recursive message types are supported: a message-typed field whose type is already being generated further up
    * the current nesting path is omitted from the generated `PMessage`, which keeps generator construction finite.
    *
    * NOTE(review): every member of a oneof gets a value in the generated `PMessage`; which one survives decoding
    * depends on ScalaPB's generated `messageReads` — confirm if oneof coverage matters.
    *
    * @param cmp the ScalaPB companion of the message type to generate
    * @tparam A  the generated message type
    * @return a generator of messages of type `A`
    */
  def messageGen[A <: GeneratedMessage](cmp: GeneratedMessageCompanion[A]): Gen[A] =
    pMessageGen(cmp, Set.empty).map(cmp.messageReads.read)

  /** Generates the protobuf-level (`PMessage`) representation of the message described by `cmp`.
    *
    * Inspired by https://github.com/scalapb-json/scalapb-circe/blob/master/core/shared/src/main/scala/scalapb_circe/JsonFormat.scala
    *
    * @param cmp       companion of the message to generate
    * @param ancestors full names of the message types currently being generated above this one
    */
  private def pMessageGen(cmp: GeneratedMessageCompanion[_], ancestors: Set[String]): Gen[PMessage] = {
    // Protobuf structs represent arbitrary JSON objects; they are always generated empty rather than traversed.
    if (cmp.javaDescriptor.getFullName() == com.google.protobuf.Struct.getDescriptor().getFullName()) {
      Gen.const(PMessage(Map.empty))
    } else {
      type Field = (FieldDescriptor, PValue)
      val path = ancestors + cmp.scalaDescriptor.fullName
      // Fields whose generation would re-enter a message already on `path` are dropped, which breaks cycles.
      val fieldGens: Vector[Gen[Field]] =
        cmp.scalaDescriptor.fields.flatMap { fd =>
          fieldGen(cmp, fd, path).map(gen => gen.map(value => fd -> value))
        }
      Gen
        .sequence[Vector[Field], Field](fieldGens)
        .map(fields => PMessage(fields.toMap))
    }
  }

  /** Generates a value for a single field; repeated and map fields are wrapped in a `PRepeated`.
    *
    * @return `None` when generating the field would recurse into a message type already present in `path`
    */
  private def fieldGen(
      cmp: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      path: Set[String]
  ): Option[Gen[PValue]] =
    if (fd.isMapField) mapFieldGen(cmp, fd, path)
    else if (fd.isRepeated)
      singleValueGen(cmp, fd, path).map(valueGen => Gen.listOf(valueGen).map(values => PRepeated(values.toVector)))
    else singleValueGen(cmp, fd, path)

  /** Generates a map field as a repeated list of entry messages holding a key (field 1) and a value (field 2).
    *
    * @return `None` when the entry's value type would recurse into a message type already present in `path`
    */
  private def mapFieldGen(
      cmp: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      path: Set[String]
  ): Option[Gen[PValue]] =
    fd.scalaType match {
      case ScalaType.Message(entryDesc) =>
        val entryCmp = cmp.messageCompanionForFieldNumber(fd.number)
        for {
          keyFd    <- entryDesc.findFieldByNumber(1)
          valueFd  <- entryDesc.findFieldByNumber(2)
          keyGen   <- fieldGen(entryCmp, keyFd, path)
          valueGen <- fieldGen(entryCmp, valueFd, path)
        } yield {
          val entryGen = for {
            key   <- keyGen
            value <- valueGen
          } yield PMessage(Map(keyFd -> key, valueFd -> value))
          Gen.listOf(entryGen).map(entries => PRepeated(entries.toVector))
        }
      case other =>
        // A map field is expected to be backed by an entry message; anything else is a descriptor inconsistency.
        throw new IllegalArgumentException(s"Map field ${fd.name} is not backed by an entry message: $other")
    }

  /** Generates a single (non-repeated) value for `fd`.
    *
    * Ints, doubles and floats are drawn from [1, 100]; longs span the full `Long` range; strings and byte strings
    * are built from `Gen.identifier`.
    *
    * @return `None` when `fd` is a message field whose type is already present in `path`
    */
  private def singleValueGen(
      cmp: GeneratedMessageCompanion[_],
      fd: FieldDescriptor,
      path: Set[String]
  ): Option[Gen[PValue]] =
    fd.scalaType match {
      case ScalaType.Message(msgDesc) =>
        if (path.contains(msgDesc.fullName)) None
        else Some(pMessageGen(cmp.messageCompanionForFieldNumber(fd.number), path))
      case ScalaType.Enum(ed)      => Some(Gen.oneOf(ed.values).map(PEnum(_)))
      case ScalaType.Long          => Some(Gen.long.map(PLong(_)))
      case ScalaType.Boolean       => Some(Gen.oneOf(true, false).map(PBoolean(_)))
      case ScalaType.Int           => Some(Gen.chooseNum(1, 100).map(PInt(_)))
      case ScalaType.Double        => Some(Gen.chooseNum(1d, 100d).map(PDouble(_)))
      case ScalaType.Float         => Some(Gen.chooseNum(1f, 100f).map(PFloat(_)))
      case ScalaType.String        => Some(Gen.identifier.map(PString(_)))
      case ScalaType.ByteString    => Some(Gen.identifier.map(s => PByteString(ByteString.copyFromUtf8(s))))
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment