Last active
April 5, 2024 18:56
-
-
Save seddonm1/6e4354f68d7d74f71bb5e27d31d5b980 to your computer and use it in GitHub Desktop.
Makes a Spark Schema (StructType) from an input XSD file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Requires the Apache WS XmlSchema library in spark/jars (it has no transitive dependencies):
// https://repo1.maven.org/maven2/org/apache/ws/xmlschema/xmlschema-core/2.2.5/xmlschema-core-2.2.5.jar
import org.apache.ws.commons.schema.XmlSchemaCollection | |
import java.io.StringReader | |
import scala.collection.JavaConverters._ | |
import org.apache.ws.commons.schema._ | |
import org.apache.ws.commons.schema.constants.Constants | |
import org.apache.spark.sql.types._ | |
/**
 * Converts an XSD type definition into a Spark [[StructField]].
 *
 * Simple types yield a field with the placeholder name "baseName"; complex types yield a
 * struct-typed field named after the complex type. Every place in this function that embeds
 * a nested result (elements, attributes, `_VALUE`) uses only its `dataType` and supplies its
 * own field name.
 *
 * Supported shapes: xs:simpleType with an xs:restriction, xs:complexType with
 * xs:simpleContent/xs:extension, and xs:complexType whose particle is xs:all, xs:choice or
 * xs:sequence (with xs:choice children of a sequence flattened into it).
 *
 * @param xmlSchema  schema used to resolve type names referenced by xs:extension bases and
 *                   attribute declarations
 * @param schemaType the xs:simpleType or xs:complexType to convert
 * @return a nullable field describing `schemaType`
 * @throws IllegalArgumentException if the type uses a construct this converter does not handle
 */
def getStructField(xmlSchema: XmlSchema, schemaType: XmlSchemaType): StructField = {
  // Placeholder field name for simple types; see the Scaladoc above.
  val baseName = "baseName"

  // Fails with a readable message instead of an opaque scala.MatchError.
  def unsupported(what: String, found: Any): Nothing =
    throw new IllegalArgumentException(s"Unsupported $what: $found")

  // Looks a type name up in this schema first and then in the owning schema collection.
  // NOTE(review): the collection fallback is intended to cover built-in xs:* types, which
  // are presumably not registered in the user schema itself — confirm against the library.
  def resolveNamedType(typeName: javax.xml.namespace.QName): XmlSchemaType =
    Option(typeName)
      .flatMap { name =>
        Option(xmlSchema.getTypeByName(name))
          .orElse(Option(xmlSchema.getParent).flatMap(parent => Option(parent.getTypeByQName(name))))
      }
      .getOrElse(unsupported("type reference (could not be resolved)", typeName))

  // Maps a restricted simple type onto a Spark data type.
  def simpleDataType(
      simpleType: XmlSchemaSimpleType,
      restriction: XmlSchemaSimpleTypeRestriction): DataType = {
    // A restriction directly on xs:anySimpleType is matched by the type's own name;
    // any other restriction is matched by the name of its base type.
    val matchType =
      if (restriction.getBaseTypeName == Constants.XSD_ANYSIMPLETYPE) simpleType.getQName
      else restriction.getBaseTypeName

    matchType match {
      case Constants.XSD_BOOLEAN => BooleanType
      // xs:byte is a signed 8-bit integer, not binary data.
      case Constants.XSD_BYTE => ByteType
      case Constants.XSD_DECIMAL =>
        // Use the xs:fractionDigits facet as the scale when present. The facet value is
        // exposed as an Object, so go through toString rather than casting to String.
        val scale = restriction.getFacets.asScala.collectFirst {
          case facet: XmlSchemaFractionDigitsFacet => facet.getValue.toString.toInt
        }
        DecimalType(38, scale.getOrElse(18))
      case Constants.XSD_DOUBLE => DoubleType
      case Constants.XSD_FLOAT => FloatType
      case Constants.XSD_INTEGER |
          Constants.XSD_NEGATIVEINTEGER |
          Constants.XSD_NONNEGATIVEINTEGER |
          Constants.XSD_NONPOSITIVEINTEGER |
          Constants.XSD_POSITIVEINTEGER |
          Constants.XSD_SHORT |
          Constants.XSD_UNSIGNEDSHORT => IntegerType
      // Unsigned 32-bit values do not fit in a signed Int.
      case Constants.XSD_LONG | Constants.XSD_UNSIGNEDINT => LongType
      // Unsigned 64-bit values do not fit in a signed Long.
      case Constants.XSD_UNSIGNEDLONG => DecimalType(38, 0)
      case Constants.XSD_BASE64 |
          Constants.XSD_DATE |
          Constants.XSD_DATETIME |
          Constants.XSD_STRING |
          Constants.XSD_TIME => StringType
      // Any other simple type is read as its lexical (string) form.
      case _ => StringType
    }
  }

  // Turns one member of an xs:all / xs:choice / xs:sequence into a field. Elements that may
  // repeat (maxOccurs != 1) become nullable arrays; otherwise `nullable` is applied as given.
  def memberField(member: AnyRef, nullable: Boolean): StructField = member match {
    case element: XmlSchemaElement =>
      val dataType = getStructField(xmlSchema, element.getSchemaType).dataType
      if (element.getMaxOccurs == 1) StructField(element.getName, dataType, nullable)
      else StructField(element.getName, ArrayType(dataType, true), true)
    case other => unsupported("particle member (only xs:element is handled)", other)
  }

  schemaType match {
    // xs:simpleType
    case simpleType: XmlSchemaSimpleType =>
      simpleType.getContent match {
        case restriction: XmlSchemaSimpleTypeRestriction =>
          StructField(baseName, simpleDataType(simpleType, restriction), true)
        case other => unsupported("xs:simpleType content (only xs:restriction is handled)", other)
      }

    // xs:complexType
    case complexType: XmlSchemaComplexType =>
      val structType = Option(complexType.getContentModel) match {
        // xs:simpleContent: text value plus attributes
        case Some(simpleContent: XmlSchemaSimpleContent) =>
          simpleContent.getContent match {
            case extension: XmlSchemaSimpleContentExtension =>
              val valueType =
                getStructField(xmlSchema, resolveNamedType(extension.getBaseTypeName)).dataType
              val value = StructField("_VALUE", valueType, true)
              val attributes = extension.getAttributes.asScala.toList.map {
                case attribute: XmlSchemaAttribute =>
                  val attributeType =
                    getStructField(xmlSchema, resolveNamedType(attribute.getSchemaTypeName)).dataType
                  StructField(s"_${attribute.getName}", attributeType, true)
                case other => unsupported("attribute declaration", other)
              }
              StructType(value :: attributes)
            case other =>
              unsupported("xs:simpleContent content (only xs:extension is handled)", other)
          }

        case Some(other) => unsupported("content model (only xs:simpleContent is handled)", other)

        // No content model: the children are described by the type's particle.
        case None =>
          complexType.getParticle match {
            // xs:all
            case all: XmlSchemaAll =>
              StructType(all.getItems.asScala.toList.map(memberField(_, true)))

            // xs:choice
            case choice: XmlSchemaChoice =>
              StructType(choice.getItems.asScala.toList.map(memberField(_, true)))

            // xs:sequence
            case sequence: XmlSchemaSequence =>
              val fields = sequence.getItems.asScala.toList.flatMap {
                // Flatten xs:choice nodes; each alternative may be absent, hence nullable.
                case choice: XmlSchemaChoice =>
                  choice.getItems.asScala.toList.map(memberField(_, true))
                // A plain element is nullable only when it is optional (minOccurs == 0).
                case element: XmlSchemaElement =>
                  List(memberField(element, element.getMinOccurs == 0))
                case other => unsupported("xs:sequence member", other)
              }
              StructType(fields)

            case other =>
              unsupported("particle (only xs:all, xs:choice and xs:sequence are handled)", other)
          }
      }
      StructField(complexType.getName, structType, true)

    case other => unsupported("schema type", other)
  }
}
/**
 * Builds the Spark schema for an XSD from the first top-level element it declares.
 *
 * If that element's type is anonymous, the type is given the element's local name (this
 * mutates the schema object) so the resulting field carries a usable name.
 *
 * @param xmlSchema the parsed XSD
 * @return a single-field [[StructType]] describing the first top-level element's type
 * @throws NoSuchElementException if the schema declares no top-level elements
 */
def getStructType(xmlSchema: XmlSchema): StructType = {
  // headOption instead of head so an empty schema fails with a descriptive message.
  val (rootName, rootElement) = xmlSchema.getElements.asScala.headOption.getOrElse(
    throw new NoSuchElementException("XSD schema declares no top-level xs:element")
  )
  val schemaType = rootElement.getSchemaType
  if (schemaType.isAnonymous) {
    schemaType.setName(rootName.getLocalPart)
  }
  StructType(getStructField(xmlSchema, schemaType) :: Nil)
}
// Read the XSD with the "wholetext" option and take the string in the first row.
val df = spark.read.option("wholetext", "true").text("/src/pain.001.001.03.xsd")
val xsdText = df.head.getString(0)
// Parse the XSD text and derive the Spark schema from it.
val xmlSchemaCollection = new XmlSchemaCollection()
val xmlSchema = xmlSchemaCollection.read(new StringReader(xsdText))
val sparkSchema = getStructType(xmlSchema)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment