Created
September 27, 2024 14:44
-
-
Save pietersp/478432264db080e33c8beda0f23a4f02 to your computer and use it in GitHub Desktop.
Validate json messages against protobuf schema
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//> using scala 2.13.15 | |
//> using dep "com.google.protobuf:protobuf-java:4.28.2" | |
//> using dep "com.google.protobuf:protobuf-java-util:4.28.2" | |
//> using dep "io.circe::circe-parser:0.14.10" | |
import com.google.protobuf.DescriptorProtos.FileDescriptorSet | |
import com.google.protobuf.Descriptors.{Descriptor, FileDescriptor} | |
import com.google.protobuf.DynamicMessage | |
import com.google.protobuf.util.JsonFormat | |
import scala.util.Using | |
import scala.util.Success | |
import scala.util.Failure | |
import java.io.FileInputStream | |
import scala.jdk.CollectionConverters.CollectionHasAsScala | |
import scala.util.Try | |
// This script does its best to validate a JSON message against a Protobuf schema.
// This is useful for validating messages that are sent to a Pub/Sub topic.
// It checks that all required fields are present and that the JSON event
// contains no fields that are missing from the schema.
// | |
// Run this script as follows: | |
// protoc --descriptor_set_out=gbq.pb gbq.proto && scala-cli run ProtoValidator.sc | |
// Where gbq.proto is the schema that the pubsub topic is defined with | |
// Also make sure there is a sample of an event in the current directory (called data.json) | |
// Compiled schema, created by running `protoc --descriptor_set_out=gbq.pb gbq.proto`.
val schemaFile: String = "gbq.pb"

// Sample JSON event that should be validated.
val eventJsonFile: String = "data.json"

// Name of the message type to look up in the descriptor set.
val messageTypeName: String = "DriverAllocationRecord"
/** Error raised when a JSON event cannot be reconciled with the Protobuf schema.
  *
  * The message is forwarded to `RuntimeException`, so `getMessage` returns it
  * (previously it was `null` because the no-arg super constructor was used).
  *
  * @param message    human-readable description of the problem
  * @param inputJson  the original JSON involved in the failure, if any
  * @param outputJson the round-tripped JSON involved in the failure, if any
  */
case class SchemaException(
    message: String,
    inputJson: Option[String] = None,
    outputJson: Option[String] = None
) extends RuntimeException(message) {

  /** Multi-line report: the message followed by both JSON payloads ("N/A" when absent). */
  override def toString(): String = {
    s"SchemaException: $message\n" +
      s"Input JSON:\n ${inputJson.getOrElse("N/A")}\n" +
      s"Output JSON:\n ${outputJson.getOrElse("N/A")}\n"
  }
}
/** Finds the top-level message descriptor named `messageTypeName` in a descriptor set.
  *
  * Files are built in the order they appear in the set. Each file is linked against
  * the files built before it that it lists as dependencies, so schemas with imports
  * work when the imported files precede them in the set (e.g. a set produced with
  * `protoc --include_imports`). A file without imports is built exactly as before.
  *
  * @param descriptorSet the parsed output of `protoc --descriptor_set_out=...`
  * @return the matching descriptor; a `Failure` holding the descriptor-building error
  *         if a file cannot be built, or a [[SchemaException]] if no top-level message
  *         has the requested name
  */
def selectDescriptor(descriptorSet: FileDescriptorSet): Try[Descriptor] =
  Try {
    descriptorSet.getFileList.asScala
      .foldLeft(Vector.empty[FileDescriptor]) { (built, file) =>
        // Keep the dependency order declared by the file itself.
        val dependencies = file.getDependencyList.asScala
          .flatMap(name => built.find(_.getName == name))
          .toArray
        // Throws if a declared dependency is missing; the surrounding Try captures it.
        built :+ FileDescriptor.buildFrom(file, dependencies)
      }
      .flatMap(_.getMessageTypes.asScala)
      .find(_.getName == messageTypeName)
  }.flatMap { found =>
    found
      .toRight(SchemaException(s"Message type $messageTypeName not found in descriptor set"))
      .toTry
  }
/** Compares the original JSON with its round-tripped form.
  *
  * Both documents are parsed and rendered compactly with sorted object keys, and
  * the two renderings are compared as strings.
  *
  * @param input  the original JSON text
  * @param output the JSON text printed back from the parsed message
  * @return `Success` with a confirmation message when the renderings are equal,
  *         otherwise a `Failure` wrapping a [[SchemaException]]
  */
def compareJson(input: String, output: String): Try[String] = {
  import io.circe.parser.parse

  // Compact rendering with sorted keys, or the parsing failure.
  def canonical(raw: String) = parse(raw).map(_.noSpacesSortKeys)

  canonical(input) match {
    // An unparsable original takes precedence over any problem with the output.
    case Left(err) =>
      Failure(SchemaException(s"Original JSON is invalid: ${err.message}", inputJson = Some(input)))
    case Right(original) =>
      canonical(output) match {
        case Left(err) =>
          Failure(SchemaException(s"Round-tripped JSON is invalid: ${err.message}", outputJson = Some(output)))
        case Right(roundTripped) =>
          if (original == roundTripped)
            Success("JSON is valid according to the Protobuf schema!")
          else
            Failure(
              SchemaException(
                "Original and round-tripped are valid JSON but do not match",
                Some(original),
                Some(roundTripped)
              )
            )
      }
  }
}
/** Validates a JSON document against a message descriptor by round-tripping it.
  *
  * The JSON is merged into a `DynamicMessage` for `descriptor`, printed back to JSON
  * using the original proto field names, and compared with the input via `compareJson`.
  *
  * Parsing happens inside `Try`, so JSON that does not fit the schema yields a
  * `Failure` rather than an exception thrown from this function.
  *
  * @param inputJsonString the JSON text to validate
  * @param descriptor      descriptor of the message type the JSON should conform to
  * @return the result of `compareJson` on success, or a `Failure` carrying the
  *         parse, build or print error
  */
def validateJson(inputJsonString: String, descriptor: Descriptor): Try[String] = {
  import JsonFormat._

  val typeRegistry = TypeRegistry.newBuilder().add(descriptor).build()
  val parser: Parser = JsonFormat.parser().usingTypeRegistry(typeRegistry)
  val printer: Printer = JsonFormat
    .printer()
    .usingTypeRegistry(typeRegistry)
    .preservingProtoFieldNames()

  for {
    message <- Try {
      val builder: DynamicMessage.Builder = DynamicMessage.newBuilder(descriptor)
      // merge is the step that rejects JSON which does not match the schema.
      parser.merge(inputJsonString, builder)
      builder.build()
    }
    outputJsonString <- Try(printer.print(message)) // Convert the message back to JSON
    output <- compareJson(inputJsonString, outputJsonString) // Compare original vs round-tripped JSON
  } yield output
}
// Read the sample event, load the compiled schema, pick the message descriptor and
// validate. The first failing step short-circuits the rest.
val result: Try[String] =
  Using(scala.io.Source.fromFile(eventJsonFile))(_.mkString).flatMap { input =>
    Using(new FileInputStream(schemaFile))(in => FileDescriptorSet.parseFrom(in))
      .flatMap { descriptorSet =>
        selectDescriptor(descriptorSet).flatMap { descriptor =>
          validateJson(input, descriptor)
        }
      }
  }

// Print the outcome (a Success or a Failure) to standard output.
println(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment