Skip to content

Instantly share code, notes, and snippets.

@pietersp
Created September 27, 2024 14:44
Show Gist options
  • Save pietersp/478432264db080e33c8beda0f23a4f02 to your computer and use it in GitHub Desktop.
Save pietersp/478432264db080e33c8beda0f23a4f02 to your computer and use it in GitHub Desktop.
Validate json messages against protobuf schema
//> using scala 2.13.15
//> using dep "com.google.protobuf:protobuf-java:4.28.2"
//> using dep "com.google.protobuf:protobuf-java-util:4.28.2"
//> using dep "io.circe::circe-parser:0.14.10"
import com.google.protobuf.DescriptorProtos.FileDescriptorSet
import com.google.protobuf.Descriptors.{Descriptor, FileDescriptor}
import com.google.protobuf.DynamicMessage
import com.google.protobuf.util.JsonFormat
import scala.util.Using
import scala.util.Success
import scala.util.Failure
import java.io.FileInputStream
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.util.Try
// This script does its best to validate a JSON message against a Protobuf schema.
// This is useful for validating messages that are sent to a Pub/Sub topic.
// It parses the JSON into a Protobuf message, prints that message back to JSON and compares
// the two. This checks that all required fields are present and that the JSON event contains
// no fields that are absent from the schema.
//
// Run this script as follows:
// protoc --descriptor_set_out=gbq.pb gbq.proto && scala-cli run ProtoValidator.sc
// where gbq.proto is the schema that the Pub/Sub topic is defined with.
// Also make sure a sample event, named data.json, is in the current directory.
// Script inputs.
//  - schemaFile: binary FileDescriptorSet produced by
//    `protoc --descriptor_set_out=gbq.pb gbq.proto`.
//  - eventJsonFile: the sample JSON event to validate.
//  - messageTypeName: the message type in the schema the event must conform to.
val schemaFile: String = "gbq.pb"
val eventJsonFile: String = "data.json"
val messageTypeName: String = "DriverAllocationRecord"
/** Error reported when the sample JSON does not conform to the Protobuf schema.
  *
  * @param message    human-readable description of the problem; it is also passed to
  *                   `RuntimeException`, so `getMessage` returns it (previously it was null)
  * @param inputJson  the input JSON involved in the failure, if available
  * @param outputJson the round-tripped JSON involved in the failure, if available
  */
case class SchemaException(
    message: String,
    inputJson: Option[String] = None,
    outputJson: Option[String] = None
) extends RuntimeException(message) {
  // Custom rendering so that printing a Failure(SchemaException) shows both JSON documents
  // side by side; missing documents are shown as "N/A".
  override def toString(): String = {
    s"SchemaException: $message\n" +
      s"Input JSON:\n ${inputJson.getOrElse("N/A")}\n" +
      s"Output JSON:\n ${outputJson.getOrElse("N/A")}\n"
  }
}
/** Finds the message descriptor named `typeName` in a compiled `FileDescriptorSet`.
  *
  * Every file in the set is built with its imports resolved from the same set. Schemas that
  * import other `.proto` files therefore work, provided the set contains those files (i.e. it
  * was generated with `protoc --include_imports`). Top-level messages are searched before
  * nested ones. A message matches on either its simple name or its fully-qualified name.
  *
  * @param descriptorSet the parsed output of `protoc --descriptor_set_out`
  * @param typeName      message to look up; defaults to the script-level `messageTypeName`
  * @return the matching descriptor, or a `Failure` if a file cannot be built (e.g. a missing
  *         import) or no message matches
  */
def selectDescriptor(
    descriptorSet: FileDescriptorSet,
    typeName: String = messageTypeName
): Try[Descriptor] = {
  val protosByName = descriptorSet.getFileList.asScala.map(p => p.getName -> p).toMap
  val built = scala.collection.mutable.Map.empty[String, FileDescriptor]

  // Builds one file after (recursively) building its imports; memoised so that a file
  // imported by several others is only built once.
  def build(fileName: String): FileDescriptor =
    built.get(fileName) match {
      case Some(fd) => fd
      case None =>
        val proto = protosByName.getOrElse(
          fileName,
          throw SchemaException(
            s"Dependency $fileName not found in descriptor set; regenerate it with protoc --include_imports"
          )
        )
        val deps = proto.getDependencyList.asScala.map(build).toArray
        val fd = FileDescriptor.buildFrom(proto, deps)
        built.update(fileName, fd)
        fd
    }

  // All messages declared inside `d` at any depth, excluding `d` itself.
  def nestedTypes(d: Descriptor): List[Descriptor] =
    d.getNestedTypes.asScala.toList.flatMap(n => n :: nestedTypes(n))

  for {
    // buildFrom throws DescriptorValidationException; keep it inside Try so callers get a Failure.
    topLevel <- Try(
      descriptorSet.getFileList.asScala.toList
        .flatMap(p => build(p.getName).getMessageTypes.asScala)
    )
    descriptor <- (topLevel ++ topLevel.flatMap(nestedTypes))
      .find(d => d.getName == typeName || d.getFullName == typeName)
      .toRight(SchemaException(s"Message type $typeName not found in descriptor set"))
      .toTry
  } yield descriptor
}
/** Checks that the original JSON and its Protobuf round trip describe the same document.
  *
  * Both strings are parsed with circe and re-rendered compactly with sorted keys, so key order
  * and whitespace do not matter. The input is checked first: if it is not valid JSON, that error
  * is reported regardless of the round-tripped output.
  *
  * @param input  the JSON event as originally supplied
  * @param output the JSON produced by printing the parsed Protobuf message
  * @return a success message when both canonical forms are equal, otherwise a `Failure`
  *         carrying a `SchemaException` with the relevant JSON attached
  */
def compareJson(input: String, output: String): Try[String] = {
  import io.circe.parser.parse

  def canonical(json: String) = parse(json).map(_.noSpacesSortKeys)

  canonical(input) match {
    case Left(failure) =>
      Failure(SchemaException(s"Original JSON is invalid: ${failure.message}", inputJson = Some(input)))
    case Right(expected) =>
      canonical(output) match {
        case Left(failure) =>
          Failure(SchemaException(s"Round-tripped JSON is invalid: ${failure.message}", outputJson = Some(output)))
        case Right(actual) if actual == expected =>
          Success("JSON is valid according to the Protobuf schema!")
        case Right(actual) =>
          Failure(SchemaException("Original and round-tripped are valid JSON but do not match", Some(expected), Some(actual)))
      }
  }
}
/** Validates `inputJsonString` against `descriptor` by a Protobuf round trip.
  *
  * The validation runs in four steps:
  *  1. The JSON is merged into a `DynamicMessage`. JsonFormat's parser rejects fields that are
  *     not in the schema and values of the wrong type.
  *  2. The message is built, which fails on missing proto2 `required` fields.
  *  3. The message is printed back to JSON using the proto field names.
  *  4. The printed JSON is compared with the input via `compareJson`.
  *
  * NOTE(review): by default the printer omits fields holding their default value and renders
  * 64-bit integers as JSON strings. Inputs that spell those out explicitly can therefore be
  * reported as mismatches even though they parse.
  *
  * @param inputJsonString the JSON event to validate
  * @param descriptor      the message type the event should conform to
  * @return a success message, or a `Failure` describing the first problem found
  */
def validateJson(inputJsonString: String, descriptor: Descriptor): Try[String] = {
  import JsonFormat._
  // Registering the descriptor lets the parser/printer resolve it inside google.protobuf.Any.
  val typeRegistry = TypeRegistry.newBuilder().add(descriptor).build()
  val parser: Parser = JsonFormat.parser().usingTypeRegistry(typeRegistry)
  val printer: Printer = JsonFormat
    .printer()
    .usingTypeRegistry(typeRegistry)
    .preservingProtoFieldNames()
  val builder: DynamicMessage.Builder = DynamicMessage.newBuilder(descriptor)
  for {
    // merge throws InvalidProtocolBufferException for unknown fields or mistyped values; it must
    // run inside Try so the error is returned as a Failure instead of escaping the function.
    _ <- Try(parser.merge(inputJsonString, builder))
    message <- Try(builder.build()) // Fails if proto2 required fields are missing
    outputJsonString <- Try(printer.print(message)) // Convert the message back to JSON
    output <- compareJson(inputJsonString, outputJsonString) // Compare original with round trip
  } yield output
}
// Script entry point: read the sample event and the compiled schema, then validate.
// Each step short-circuits into a Failure, which is printed below together with any
// SchemaException details.
val result: Try[String] = for {
  // JSON exchanged between systems is UTF-8 (RFC 8259). Decode it explicitly rather than with the
  // platform default charset, which would mangle non-ASCII text and cause spurious mismatches.
  input <- Using(scala.io.Source.fromFile(eventJsonFile)(scala.io.Codec.UTF8))(_.mkString)
  descriptorSet <- Using(new FileInputStream(schemaFile))(
    FileDescriptorSet.parseFrom(_)
  )
  descriptor <- selectDescriptor(descriptorSet)
  validationResult <- validateJson(input, descriptor)
} yield validationResult
println(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment