@hakanilter
Last active November 22, 2024 13:31
Spark JDBC Import Over SSH
import os
import json
import boto3
import base64
from sshtunnel import SSHTunnelForwarder
"""
This Python class, **`JDBCImportOverSSH`**, transfers data from a JDBC-compatible source to Delta tables through an SSH tunnel. Overview:
1. **Configuration Handling**:
- Retrieves JDBC and SSH configuration details from AWS Secrets Manager.
- Manages the SSH private key securely by writing it to a temporary file and deleting it after use.
2. **SSH Tunnel Setup**:
- Establishes an SSH tunnel to securely connect to the database through a remote host.
3. **Data Import Logic**:
- For each table in the import configuration:
- Determines the bounds for partitioning based on a specified column.
- Reads the data from the source table over JDBC with partitioning to optimize performance.
- Writes the data to the target table in Delta format.
4. **Key Features**:
- Uses Apache Spark for JDBC operations and Delta table writes.
- Ensures secure connections with SSH and AWS Secrets Manager.
- Automatically cleans up resources (SSH key file) after execution.
This class is useful for securely migrating or synchronizing data from remote databases to Delta Lake in distributed environments.
Sample Import Config:
```json
{
    "my source": [
        {
            "source_table": "my_schema.my_source_table",
            "target_table": "some_catalog.target_schema.target_table",
            "partition_column": "id",
            "num_partitions": 4
        }
    ]
}
```
Sample JDBC Config:
```json
{
    "host": "db_host",
    "port": 5439,
    "database": "my_database",
    "user": "my_user",
    "pass": "my_pass",
    "driver": "com.amazon.redshift.jdbc42.Driver"
}
```
Sample SSH Config:
```json
{
    "host": "10.12.8.9",
    "port": 22,
    "user": "ubuntu",
    "key": "<base64-encoded-pem-file-content>"
}
```
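Note: `run()` expects the **list** of table configs for a single source (e.g. the value under `"my source"` above), and the SSH `key` field is the PEM private key content encoded with e.g. `base64 -w0 my_key.pem`.
Sample Usage (a minimal sketch; the secret names are hypothetical, and `spark` must already exist, e.g. in a Databricks notebook):
```python
import_config = {
    "my source": [
        {
            "source_table": "my_schema.my_source_table",
            "target_table": "some_catalog.target_schema.target_table",
            "partition_column": "id",
            "num_partitions": 4
        }
    ]
}

for source, tables in import_config.items():
    JDBCImportOverSSH().run(
        tables,
        jdbc_secret="prod/my-source/jdbc",  # hypothetical secret name
        ssh_secret="prod/my-source/ssh")    # hypothetical secret name
```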
"""

class JDBCImportOverSSH:

    def run(self, import_config, jdbc_secret, ssh_secret):
        self.import_config = import_config
        # Get connection details from secrets manager
        self.jdbc_config = json.loads(self._get_secret(jdbc_secret))
        # Get SSH config
        self.ssh_config = json.loads(self._get_secret(ssh_secret))
        try:
            self._create_ssh_key_file()
            self._import_all()
        finally:
            self._delete_ssh_key_file()

    def _import_all(self):
        # Connect to the database through the SSH tunnel
        with SSHTunnelForwarder(
            (self.ssh_config["host"], self.ssh_config["port"]),
            ssh_username=self.ssh_config["user"],
            ssh_pkey=self.ssh_file,
            remote_bind_address=(self.jdbc_config["host"], self.jdbc_config["port"])
        ) as tunnel:
            # Import all tables
            for config in self.import_config:
                print(f"Start importing: {config}")
                self._import_table(
                    tunnel,
                    config["source_table"],
                    config["target_table"],
                    config["partition_column"],
                    config["num_partitions"])

    def _import_table(self, tunnel, source_table, target_table, partition_column, num_partitions):
        redshift_db = self.jdbc_config["database"]
        # The tunnel exposes the remote database on an ephemeral local port
        redshift_url = f"jdbc:redshift://localhost:{tunnel.local_bind_address[1]}/{redshift_db}?ssl=True"
        redshift_properties = {
            "user": self.jdbc_config["user"],
            "password": self.jdbc_config["pass"],
            "driver": self.jdbc_config["driver"]
        }

        # Find upper and lower bounds for the partition column
        query = f"""
            (SELECT
                MIN({partition_column}) AS lower_bound,
                MAX({partition_column}) AS upper_bound
            FROM
                {source_table}) AS redshift_table
        """
        # `spark` is assumed to be a pre-existing SparkSession
        # (e.g. provided by a Databricks notebook or spark-submit shell)
        df = spark.read.jdbc(url=redshift_url, table=query, properties=redshift_properties)
        row = df.collect()[0]
        lower_bound = row.lower_bound
        upper_bound = row.upper_bound
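
        # Spark splits the read into `num_partitions` ranges over
        # [lower_bound, upper_bound]; rows outside the bounds are not filtered
        # out, they just land in the first or last partition.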
        # Read the table
        query = f"(SELECT * FROM {source_table}) AS redshift_table"
        df = spark.read \
            .option("numPartitions", num_partitions) \
            .option("partitionColumn", partition_column) \
            .option("lowerBound", lower_bound) \
            .option("upperBound", upper_bound) \
            .jdbc(url=redshift_url, table=query, properties=redshift_properties)

        # Write as a Delta table
        df.write.mode("overwrite") \
            .format("delta") \
            .saveAsTable(target_table)
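        # Note: mode("overwrite") replaces the target table's contents on every
        # run; if the source schema can evolve, consider adding
        # .option("overwriteSchema", "true") to the writer (a Delta Lake option).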

    def _create_ssh_key_file(self):
        self.ssh_file = "/tmp/ssh_key"
        with open(self.ssh_file, "w") as f:
            content = base64.b64decode(self.ssh_config["key"]).decode("ascii")
            f.write(content)
        # Restrict permissions: the file contains a private key
        os.chmod(self.ssh_file, 0o600)

    def _delete_ssh_key_file(self):
        try:
            os.remove(self.ssh_file)
        except Exception as e:
            print(e)

    def _get_secret(self, secret_name, region_name="eu-west-1"):
        session = boto3.session.Session()
        client = session.client(service_name="secretsmanager", region_name=region_name)
        response = client.get_secret_value(SecretId=secret_name)
        return response["SecretString"]