Spark JDBC Import Over SSH
import os
import json
import boto3
import base64
from sshtunnel import SSHTunnelForwarder
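# NOTE (assumption): this gist appears to target a Databricks-style environment where a
# SparkSession named `spark` is predefined. Outside such an environment, obtain one
# explicitly; `getOrCreate()` simply returns the active session if one already exists.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()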
""" | |
This Python class, **`JDBCImportOverSSH`**, facilitates the transfer of data from a JDBC-compatible source to a Delta table via an SSH tunnel. Here's an overview: | |
1. **Configuration Handling**: | |
- Retrieves JDBC and SSH configuration details from AWS Secrets Manager. | |
- Manages the SSH private key securely by writing it to a temporary file and deleting it after use. | |
2. **SSH Tunnel Setup**: | |
- Establishes an SSH tunnel to securely connect to the database through a remote host. | |
3. **Data Import Logic**: | |
- For each table in the import configuration: | |
- Determines the bounds for partitioning based on a specified column. | |
- Reads the data from the source table over JDBC with partitioning to optimize performance. | |
- Writes the data to the target table in Delta format. | |
4. **Key Features**: | |
- Uses Apache Spark for JDBC operations and Delta table writes. | |
- Ensures secure connections with SSH and AWS Secrets Manager. | |
- Automatically cleans up resources (SSH key file) after execution. | |
This class is useful for securely migrating or synchronizing data from remote databases to Delta Lake in distributed environments. | |
Sample Import Config:
```json
{
    "my source": [
        {
            "source_table": "my_schema.my_source_table",
            "target_table": "some_catalog.target_schema.target_table",
            "partition_column": "id",
            "num_partitions": 4
        }
    ]
}
```

Sample JDBC Config:
```json
{
    "host": "db_host",
    "port": 5439,
    "database": "my_database",
    "user": "my_user",
    "pass": "my_pass",
    "driver": "com.amazon.redshift.jdbc42.Driver"
}
```

Sample SSH Config:
```json
{
    "host": "10.12.8.9",
    "port": 22,
    "user": "ubuntu",
    "key": "<base64-encoded-pem-file-content>"
}
```
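
Sample Usage (a minimal sketch, not part of the original gist; assumes `spark` points to a live SparkSession and the secret names are placeholders for your own AWS Secrets Manager entries):
```python
import_config = {
    "my source": [
        {
            "source_table": "my_schema.my_source_table",
            "target_table": "some_catalog.target_schema.target_table",
            "partition_column": "id",
            "num_partitions": 4
        }
    ]
}

JDBCImportOverSSH().run(
    import_config["my source"],      # run() iterates over a list of table configs
    jdbc_secret="my-jdbc-secret",    # placeholder secret names
    ssh_secret="my-ssh-secret"
)
```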
""" | |
class JDBCImportOverSSH:

    def run(self, import_config, jdbc_secret, ssh_secret):
        self.import_config = import_config
        # Get JDBC connection details from Secrets Manager
        self.jdbc_config = json.loads(self._get_secret(jdbc_secret))
        # Get SSH config from Secrets Manager
        self.ssh_config = json.loads(self._get_secret(ssh_secret))
        try:
            self._create_ssh_key_file()
            self._import_all()
        finally:
            self._delete_ssh_key_file()

    def _import_all(self):
        # Connect to the database through the SSH tunnel
        with SSHTunnelForwarder(
            (self.ssh_config["host"], self.ssh_config["port"]),
            ssh_username=self.ssh_config["user"],
            ssh_pkey=self.ssh_file,
            remote_bind_address=(self.jdbc_config["host"], self.jdbc_config["port"])
        ) as tunnel:
            # Import all configured tables
            for config in self.import_config:
                print(f"Start importing: {config}")
                self._import_table(
                    tunnel,
                    config["source_table"],
                    config["target_table"],
                    config["partition_column"],
                    config["num_partitions"])

    def _import_table(self, tunnel, source_table, target_table, partition_column, num_partitions):
        redshift_db = self.jdbc_config["database"]
        redshift_url = f"jdbc:redshift://localhost:{tunnel.local_bind_address[1]}/{redshift_db}?ssl=True"
        redshift_properties = {
            "user": self.jdbc_config["user"],
            "password": self.jdbc_config["pass"],
            "driver": self.jdbc_config["driver"]
        }
        # Find upper and lower bounds for the partition column
        query = f"""
            (SELECT
                MIN({partition_column}) AS lower_bound,
                MAX({partition_column}) AS upper_bound
            FROM
                {source_table}) AS redshift_table
        """
        df = spark.read.jdbc(url=redshift_url, table=query, properties=redshift_properties)
        row = df.collect()[0]
        lower_bound = row.lower_bound
        upper_bound = row.upper_bound
        # Read the source table with a partitioned JDBC read
        query = f"(SELECT * FROM {source_table}) AS redshift_table"
        df = spark.read \
            .option("numPartitions", num_partitions) \
            .option("partitionColumn", partition_column) \
            .option("lowerBound", lower_bound) \
            .option("upperBound", upper_bound) \
            .jdbc(url=redshift_url, table=query, properties=redshift_properties)
        # Write as a Delta table
        df.write.mode("overwrite") \
            .format("delta") \
            .saveAsTable(target_table)

    def _create_ssh_key_file(self):
        # Decode the base64-encoded private key and write it to a temporary file
        self.ssh_file = "/tmp/ssh_key"
        with open(self.ssh_file, "w") as f:
            content = base64.b64decode(self.ssh_config["key"]).decode("ascii")
            f.write(content)

    def _delete_ssh_key_file(self):
        try:
            os.remove(self.ssh_file)
        except Exception as e:
            print(e)

    def _get_secret(self, secret_name, region_name="eu-west-1"):
        session = boto3.session.Session()
        client = session.client(service_name="secretsmanager", region_name=region_name)
        response = client.get_secret_value(SecretId=secret_name)
        return response["SecretString"]