Scala/Spark
// Read a PostgreSQL table into a typed Dataset[City] over JDBC
import java.text.SimpleDateFormat

import org.apache.spark.sql
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

case class City(
  id: Int,
  name: String,
  iata: Option[String],
  longitude: Option[Double],
  latitude: Option[Double],
  updatedTimestamp: Long)

object udfHelpers {
  // Parses a "yyyy-MM-dd HH:mm" string into epoch milliseconds
  val toTimestamp = udf((date: String) => new SimpleDateFormat("yyyy-MM-dd HH:mm").parse(date).getTime)
}

val spark = SparkSession
  .builder
  .appName("Spark Test")
  .master("local[*]")
  .getOrCreate

import spark.implicits._

def getSqlDf(host: String, port: Int, name: String, user: String, password: String, tableName: String): sql.DataFrame = {
  spark.read
    .format("jdbc")
    .option("driver", "org.postgresql.Driver")
    .option("url", s"jdbc:postgresql://${host}:${port}/${name}?zeroDateTimeBehavior=convertToNull&read_buffer_size=100M")
    .option("dbtable", tableName)
    .option("user", user)
    .option("password", password)
    .load
}

val citiesDF = getSqlDf(
    host = Settings.geoDb.host,
    port = Settings.geoDb.port,
    name = Settings.geoDb.name,
    user = Settings.geoDb.user,
    password = Settings.geoDb.password,
    tableName = "geo.\"Cities\""
  )
  .withColumn("updatedTimestamp", udfHelpers.toTimestamp(col("updatedAt")))
  .filter("iata IS NULL")
  .as[City].cache
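// Hypothetical usage (not part of the original gist): once the JDBC read succeeds,
// the typed Dataset can be inspected with ordinary Dataset operations.
citiesDF.show(10, truncate = false)
println(s"Cities without an IATA code: ${citiesDF.count}")
citiesDF
  .filter(c => c.latitude.isDefined && c.longitude.isDefined)
  .map(_.name)
  .show(5)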
// Parsing JSON to a Map using a json4s CustomSerializer
// == JSON format ==
//
// {
//   "AAA": {
//     "latitude": 1.234567,
//     "longitude": -1.234567
//   },
//   ...
// }

import java.nio.file.{Files, Paths}

import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import org.json4s.native.JsonMethods.parse
import org.json4s.{CustomSerializer, DefaultFormats}

case class Location(longitude: Double, latitude: Double)

// Converts a JObject of the shape above into a Location, and back again
class LocationSerializer extends CustomSerializer[Location](format => ({
  case obj: JObject =>
    val longitude = (obj \ "longitude").values.asInstanceOf[Double]
    val latitude = (obj \ "latitude").values.asInstanceOf[Double]
    Location(longitude, latitude)
}, {
  case obj: Location =>
    ("longitude" -> obj.longitude) ~
    ("latitude" -> obj.latitude)
}))

implicit val formats = DefaultFormats + new LocationSerializer()

val jsonString = new String(Files.readAllBytes(Paths.get(s"${Settings.trends.rootDataFolder}/locations_cache.json")))
parse(jsonString).extract[Map[String, Location]]
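// Hypothetical usage (not part of the original gist): the snippet above discards the
// extracted Map; a caller would typically bind it and look entries up by IATA code.
val locations: Map[String, Location] = parse(jsonString).extract[Map[String, Location]]
locations.get("AAA").foreach { loc =>
  println(s"AAA -> lon=${loc.longitude}, lat=${loc.latitude}")
}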
// Parse JSON using Spark DataFrames
// == JSON format (one object per line) ==
//
// { "iata": "AAA", "location": { "lon": 123.456, "lat": -123.456 }}
// { "iata": "BBB", "location": { "lon": 123.456, "lat": -123.456 }}
// ...

import org.apache.spark.sql.{Row, SparkSession}

case class Location(longitude: Double, latitude: Double)
case class IataLocation(iata: String, location: Location)

val spark = SparkSession
  .builder
  .appName("Spark Test")
  .master("local[*]")
  .getOrCreate

import spark.implicits._

val locationCache = spark
  .read
  .json(s"file://${Settings.trends.rootDataFolder}/locations-cache.json")
  .map { item =>
    // Nested JSON objects come back as Rows; pull out the coordinates by field name
    val location = item.getAs[Row]("location")
    IataLocation(
      iata = item.getAs[String]("iata"),
      location = Location(
        longitude = location.getAs[Double]("lon"),
        latitude = location.getAs[Double]("lat")
      )
    )
  }
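// Hypothetical usage (not part of the original gist): materialise the Dataset as a local
// lookup table keyed by IATA code. Reasonable for a small cache file; collect() pulls
// every row onto the driver.
val byIata: Map[String, Location] = locationCache
  .collect()
  .map(entry => entry.iata -> entry.location)
  .toMap
println(byIata.get("AAA"))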
// A typical standard machine learning workflow is as follows:
// * Loading data (aka data ingestion)
// * Extracting features (aka feature extraction)
// * Training the model (aka model training)
// * Evaluating (making predictions)

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.regression.RandomForestRegressor

// ETL
val loadedDataframe = ... // get data from source(s)
val labeledPoints = ... // pack the data into a Dataset (typed) case class with label and features fields

// Split the data into training and test sets (30% held out for testing).
val Array(train, test) = labeledPoints.randomSplit(Array(.7, .3), 10204L)
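// The pipeline below references a `vectorAssembler` stage that is not defined in the
// original snippet. A minimal sketch of what it might look like, assuming hypothetical
// raw numeric columns "f1", "f2" and "f3" that get packed into the "features" vector.
// (If labeledPoints already carries a prebuilt "features" column, as the comment above
// suggests, this stage would be redundant and could be dropped from the stages list.)
val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("f1", "f2", "f3"))
  .setOutputCol("features")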
// Scale and normalize features based on their individual variances
val featureScaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)
  .setWithMean(false)

// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
  .setInputCol("scaledFeatures")
  .setOutputCol("indexedFeatures")

// Train a RandomForest model.
val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

val pipeline = new Pipeline("MyModel")
  .setStages(Array(vectorAssembler, featureScaler, featureIndexer, rf))

// Train the model. This runs the input data through the pipeline, ending with the RandomForestRegressor.
val randomForestModel = pipeline.fit(train)

// Make predictions.
val predictions = randomForestModel.transform(test)

// Print the first 10 labels and predictions
predictions.select("prediction", "label").show(10)

// Evaluate the pipeline's prediction accuracy
val evaluator = new RegressionEvaluator
val r2 = evaluator.setMetricName("r2").evaluate(predictions)
val rmse = evaluator.setMetricName("rmse").evaluate(predictions)
val mse = evaluator.setMetricName("mse").evaluate(predictions)
val mae = evaluator.setMetricName("mae").evaluate(predictions)

println(s"Mean Squared Error: ${mse}")
println(s"Root Mean Squared Error: ${rmse}")
println(s"Coefficient of Determination (R-squared): ${r2}")
println(s"Mean Absolute Error: ${mae}")
println(s"Explained params: ${evaluator.explainParams}")

// Serialize the fitted pipeline model to disk (the fitted PipelineModel, not the unfitted
// Pipeline, so that PipelineModel.load below can read it back).
randomForestModel.write.overwrite.save(s"${somePath}/MyPipeline")

// Deserialize the pipeline model
val deserializedPipeline = PipelineModel.load(s"${somePath}/MyPipeline")

// As with the predictions above, the loaded pipeline can now be used to score new data.
val reloadedPredictions = deserializedPipeline.transform(test)
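// Hypothetical follow-up (not in the original gist): re-scoring the held-out set with the
// reloaded pipeline should reproduce the metrics computed above, a cheap check that
// serialization round-tripped correctly.
val reloadedRmse = evaluator.setMetricName("rmse").evaluate(reloadedPredictions)
println(s"RMSE from the reloaded pipeline: ${reloadedRmse} (should match ${rmse})")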