Created
April 13, 2025 18:17
-
-
Save zcking/9fc46dce43d71a98a7effaefd2b15f4d to your computer and use it in GitHub Desktop.
Example of a custom PySpark data source for ingesting satellite TLE data and visualizing it with Matplotlib.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, timedelta

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from astropy import units
from astropy.coordinates import TEME, ITRS
from astropy.time import Time
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceStreamReader,
    InputPartition,
)
from pyspark.sql.types import StructType
from sgp4.api import SGP4_ERRORS, Satrec, jday
def teme_to_ecef(rx, ry, rz, jd, fr):
    """Convert a TEME position to the ECEF (ITRS) frame.

    Parameters
    ----------
    rx, ry, rz : float
        Position components in the TEME frame, in kilometers.
    jd, fr : float
        Julian date split into whole and fractional parts; their sum is
        the observation time used for the frame transformation.

    Returns
    -------
    tuple of float
        ``(x, y, z)`` in the ITRS (Earth-fixed) frame, in kilometers.
    """
    obstime = Time(jd + fr, format="jd")
    position_teme = TEME(
        rx * units.km,
        ry * units.km,
        rz * units.km,
        obstime=obstime,
    )
    position_itrs = position_teme.transform_to(ITRS(obstime=obstime))
    return position_itrs.x.value, position_itrs.y.value, position_itrs.z.value
class SatelliteDataSourceReader(DataSourceReader):
    """Batch reader that propagates a satellite's orbit from its TLE.

    Fetches the latest two-line element set (TLE) for ``norad_id`` from
    the tle.ivanstanojevic.me API once at planning time, then uses SGP4
    to predict ``size`` positions spaced ``timedelta`` apart, starting
    from the current wall-clock time.
    """

    def __init__(self, schema: StructType, options: dict):
        self.options = options
        self.norad_id = options.get("norad_id")
        # Number of position predictions to generate.
        self.size = int(options.get("size", 5))
        # "seconds=1" -> {"seconds": 1}, usable as timedelta(**self.timedelta).
        unit, _, amount = options.get("timedelta", "seconds=1").partition("=")
        self.timedelta = {unit: int(amount)}
        self.api_key = options.get("api_key", "DEMO_KEY")
        # Fetch the TLE data for the given satellite once, up front.
        tle = self.fetch_tle_data(self.norad_id, self.api_key)
        self.tle_line1 = tle["line1"]
        self.tle_line2 = tle["line2"]
        self.name = tle["name"]
        # Example: "2025-04-09T12:28:54+00:00" (timezone-aware).
        self.tle_timestamp = datetime.strptime(tle["date"], "%Y-%m-%dT%H:%M:%S%z")

    @staticmethod
    def fetch_tle_data(norad_id: str, api_key: str) -> dict:
        """Fetch the latest TLE record for ``norad_id``.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        headers = {"Accept": "*/*", "User-Agent": "curl"}
        url = f"https://tle.ivanstanojevic.me/api/tle/{norad_id}"
        # Bug fix: a timeout guards against a hung connection stalling the
        # Spark task indefinitely (requests has no default timeout).
        response = requests.get(
            url,
            headers=headers,
            params={"api_key": api_key},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()

    def read(self, partition):
        """Yield one row per predicted position.

        Note: library imports must be within the method, since this runs
        on Spark worker processes.
        """
        from sgp4.api import Satrec, jday, SGP4_ERRORS
        from datetime import datetime, timedelta

        satellite = Satrec.twoline2rv(self.tle_line1, self.tle_line2)
        step = timedelta(**self.timedelta)
        for i in range(self.size):
            ts = datetime.now() + step * i
            jd, fr = jday(ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second)
            e, r, v = satellite.sgp4(jd, fr)
            if e != 0:
                # Map SGP4 numeric error codes to a readable message.
                e = SGP4_ERRORS.get(e, 'Unknown error')
            # NOTE(review): on success ``e`` remains int 0 even though the
            # schema declares the column STRING — relies on Spark coercing
            # the value; confirm against the declared schema.
            r = {"x": r[0], "y": r[1], "z": r[2]}
            v = {"x": v[0], "y": v[1], "z": v[2]}
            yield (ts, r, v, e, self.norad_id, self.name, self.tle_timestamp,
                   self.tle_line1, self.tle_line2, jd, fr)
class SatellitePartition(InputPartition):
    """Input partition carrying one satellite's TLE snapshot.

    Each partition is self-contained so a worker can propagate the orbit
    without any further lookups.
    """

    def __init__(self, norad_id: str, name: str, tle_line1: str, tle_line2: str,
                 tle_timestamp: datetime, size: int):
        # Identity of the satellite and the TLE record to propagate from.
        self.norad_id, self.name = norad_id, name
        self.tle_line1, self.tle_line2 = tle_line1, tle_line2
        self.tle_timestamp = tle_timestamp
        # Number of position predictions to generate for this partition.
        self.size = size
class SatelliteStreamReader(DataSourceStreamReader):
    """Streaming reader that emits new orbit predictions when a
    satellite's TLE record changes.

    Offsets are dicts keyed by NORAD id; each value identifies the TLE
    (both lines plus its epoch) last seen, so a micro-batch is produced
    only for satellites whose TLE has been refreshed upstream.
    """

    def __init__(self, schema: StructType, options: dict):
        # One or more NORAD ids, comma-separated in the option string.
        self.norad_id = options.get("norad_id").split(",")
        # Number of position predictions per partition.
        self.size = int(options.get("size", 5))
        # "minutes=1" -> {"minutes": 1}, usable as timedelta(**self.timedelta).
        self.timedelta = options.get("timedelta", "seconds=1").split("=")
        self.timedelta = {self.timedelta[0]: int(self.timedelta[1])}
        self.api_key = options.get("api_key", "DEMO_KEY")

    def initialOffset(self):
        """
        Return the initial offset of the streaming data source.

        A new streaming query starts reading data from the initial offset.
        If Spark is restarting an existing query, it will restart from the
        check-pointed offset rather than the initial one.

        Returns
        -------
        dict
            A dict or recursive dict whose key and value are primitive types,
            which includes Integer, String and Boolean.
        """
        # Empty: the first latestOffset() makes every satellite look "new",
        # so each one produces a partition in the first micro-batch.
        return {}

    def latestOffset(self) -> dict:
        """
        Return the most recent offset available.

        Returns
        -------
        dict
            A dict or recursive dict whose key and value are primitive types,
            which includes Integer, String and Boolean.
        """
        # Fetch the current TLE data for each of the configured satellites.
        offsets = {}
        for norad_id in self.norad_id:
            tle = SatelliteDataSourceReader.fetch_tle_data(norad_id, self.api_key)
            offsets[norad_id] = {
                "name": tle["name"],
                "tle_line1": tle["line1"],
                "tle_line2": tle["line2"],
                "tle_timestamp": tle["date"],
            }
        return offsets

    def partitions(self, start: dict, end: dict):
        """Create one partition per satellite whose TLE changed between
        the ``start`` and ``end`` offsets (or is new in ``end``)."""
        partitions = []
        for id, val in end.items():
            changed = (
                id not in start
                or start[id]["tle_line1"] != val["tle_line1"]
                or start[id]["tle_line2"] != val["tle_line2"]
                or start[id]["tle_timestamp"] != val["tle_timestamp"]
            )
            if changed:
                partitions.append(SatellitePartition(
                    id,
                    val["name"],
                    val["tle_line1"],
                    val["tle_line2"],
                    val["tle_timestamp"],
                    self.size,
                ))
        return partitions

    def read(self, partition: SatellitePartition):
        """Yield ``partition.size`` predicted positions for one satellite.

        Imports live inside the method because it runs on Spark workers.
        """
        from datetime import datetime, timedelta
        from sgp4.api import Satrec, jday, SGP4_ERRORS

        satellite = Satrec.twoline2rv(partition.tle_line1, partition.tle_line2)
        step = timedelta(**self.timedelta)
        for i in range(partition.size):
            ts = datetime.now() + step * i
            jd, fr = jday(ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second)
            e, r, v = satellite.sgp4(jd, fr)
            if e != 0:
                # Map SGP4 numeric error codes to a readable message.
                e = SGP4_ERRORS.get(e, 'Unknown error')
            r = {"x": r[0], "y": r[1], "z": r[2]}
            v = {"x": v[0], "y": v[1], "z": v[2]}
            yield (
                ts,
                r,
                v,
                e,
                partition.norad_id,
                partition.name,
                datetime.strptime(partition.tle_timestamp, "%Y-%m-%dT%H:%M:%S%z"),
                partition.tle_line1,
                partition.tle_line2,
                jd,
                fr,
            )

    def simple_reader(self, start: dict):
        """Simple-reader fallback: return ``(rows_iterator, next_offset)``.

        Bug fix: the original referenced ``self.tle_line1``, ``self.tle_line2``,
        ``self.name`` and ``self.tle_timestamp``, none of which are ever set on
        this class (they exist only on the batch reader), so any call raised
        AttributeError. This version fetches the TLE for each configured
        satellite and generates rows the same way ``read`` does.
        """
        from datetime import datetime, timedelta
        from sgp4.api import Satrec, jday, SGP4_ERRORS

        start_idx = start.get("offset", 0)
        step = timedelta(**self.timedelta)
        rows = []
        for norad_id in self.norad_id:
            tle = SatelliteDataSourceReader.fetch_tle_data(norad_id, self.api_key)
            satellite = Satrec.twoline2rv(tle["line1"], tle["line2"])
            tle_timestamp = datetime.strptime(tle["date"], "%Y-%m-%dT%H:%M:%S%z")
            for i in range(start_idx, start_idx + self.size):
                ts = datetime.now() + step * i
                jd, fr = jday(ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second)
                e, r, v = satellite.sgp4(jd, fr)
                if e != 0:
                    e = SGP4_ERRORS.get(e, 'Unknown error')
                r = {"x": r[0], "y": r[1], "z": r[2]}
                v = {"x": v[0], "y": v[1], "z": v[2]}
                rows.append((ts, r, v, e, norad_id, tle["name"], tle_timestamp,
                             tle["line1"], tle["line2"], jd, fr))
        return (iter(rows), {"offset": start_idx + self.size})
class SatelliteDataSource(DataSource):
    """Spark data source exposing SGP4 satellite position predictions.

    Registered under the short name ``satellite``; supports both batch
    (``reader``) and streaming (``streamReader``) access with a shared
    fixed schema.
    """

    @classmethod
    def name(cls):
        # Short name used with .format("satellite").
        return "satellite"

    def schema(self):
        # Fixed DDL schema shared by the batch and streaming readers.
        return """
        ts TIMESTAMP,
        pos STRUCT<
            x:DOUBLE,
            y:DOUBLE,
            z:DOUBLE
        >,
        velocity STRUCT<
            x:DOUBLE,
            y:DOUBLE,
            z:DOUBLE
        >,
        e STRING,
        norad_id STRING,
        name STRING,
        tle_timestamp TIMESTAMP,
        tle_line1 STRING,
        tle_line2 STRING,
        jd DOUBLE,
        fr DOUBLE
        """

    def streamReader(self, schema: StructType):
        """Return the streaming reader for this source."""
        return SatelliteStreamReader(schema, self.options)

    def reader(self, schema: StructType):
        """Return the batch reader for this source."""
        return SatelliteDataSourceReader(schema, self.options)
def visualize(df: DataFrame):
    """Render satellite positions as a 3-D scatter over an Earth wireframe
    and save the result to ``satellite_plot.png``.

    Positions are converted from TEME to the Earth-fixed ECEF frame before
    plotting. Intended for use in ``foreachBatch``; the non-interactive Agg
    backend is selected so it works headless.
    """
    matplotlib.use('Agg')
    pDF = df.toPandas()
    # (Removed an unused "color_id" column the original computed with
    # pd.factorize and never read.)
    # Assign each satellite a distinct color from the viridis colormap.
    unique_ids = pDF["norad_id"].unique()
    id_to_color = {
        norad_id: cm.viridis(i / len(unique_ids))
        for i, norad_id in enumerate(unique_ids)
    }
    # Convert TEME to ECEF so positions are plotted in an Earth-fixed frame.
    pDF["x_ecef"], pDF["y_ecef"], pDF["z_ecef"] = zip(*pDF.apply(lambda row: teme_to_ecef(
        row["pos"]["x"], row["pos"]["y"], row["pos"]["z"], row["jd"], row["fr"]
    ), axis=1))
    # Plotting
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim([-7000, 7000])
    ax.set_ylim([-7000, 7000])
    ax.set_zlim([-7000, 7000])
    ax.set_xlabel('X (km)')
    ax.set_ylabel('Y (km)')
    ax.set_zlabel('Z (km)')
    # Plot the Earth as a wireframe sphere of radius 6371 km.
    u, v = np.mgrid[0:2*np.pi:100j, 0:np.pi:50j]
    x = 6371 * np.cos(u) * np.sin(v)
    y = 6371 * np.sin(u) * np.sin(v)
    z = 6371 * np.cos(v)
    ax.plot_wireframe(x, y, z, color='blue', alpha=0.1)
    # Plot the satellites with a different color for each one.
    for norad_id, g in pDF.groupby('norad_id'):
        ax.scatter(
            g["x_ecef"],
            g["y_ecef"],
            g["z_ecef"],
            color=id_to_color[norad_id],
            s=10,
            label=f"Satellite {norad_id}"
        )
    ax.legend()
    # plt.show()  # Uncomment to display the plot in a notebook
    plt.savefig("satellite_plot.png")
    # Bug fix: close the figure so repeated foreachBatch invocations do not
    # accumulate open matplotlib figures (a memory leak).
    plt.close(fig)
if __name__ == "__main__":
    spark = SparkSession.builder.appName("SatelliteStream").getOrCreate()
    spark.dataSource.register(SatelliteDataSource)

    # Stream position predictions for three satellites identified by NORAD id.
    df = (
        spark
        .readStream
        .format("satellite")
        .option("norad_id", "25544,33499,46362")
        .option("size", "240")             # 240 predictions
        .option("timedelta", "minutes=1")  # 1-minute steps: 240 x 1 min = 4 hours
        .option("api_key", "DEMO_KEY")
        .load()
    )

    # Process a single micro-batch, render it, and exit.
    # (awaitTermination() returns None, so the original's unused
    # ``q = ...`` binding has been dropped.)
    (
        df
        .writeStream
        .outputMode("append")
        .trigger(once=True)
        .foreachBatch(lambda batch_df, batch_id: visualize(batch_df))
        .start()
        .awaitTermination()
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment