Created October 24, 2024 22:08
Write a pandas DataFrame to BigQuery using the Storage Write API, with flexible column names handled via dynamically generated protos
# NOTE: The BigQueryTable class is something we wrote to wrap a bunch of operations and migrate tables,
# but you can get the gist of what we're doing. Converting types is not the hard part.
#
# NOTE: The protobuf descriptor comes from: https://github.com/googleapis/googleapis/blob/master/google/cloud/bigquery/storage/v1/annotations.proto
import big_query_storage_write_api as s

s.to_gbq(
    df=df,
    table=internal.BigQueryTable("some_table"),
)
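For context, here is a minimal sketch of the kind of DataFrame the call above expects; the column names and values are illustrative assumptions, not part of the original gist:

# Hypothetical input: column names should match the destination table's schema,
# including flexible column names that are not valid protobuf identifiers.
import pandas as pd
from datetime import datetime, timezone

df = pd.DataFrame(
    {
        "name": ["alice", "bob"],
        "signup-date (UTC)": [datetime(2024, 1, 1, tzinfo=timezone.utc), None],
        "score": [1.5, 2.0],
    }
)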
from google.cloud import bigquery
import pandas as pd
import banzai.google as google
from datetime import datetime
import json
import numpy as np
from banzai.google.big_query_table import BigQueryTable
from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types, writer
from google.api_core.exceptions import InvalidArgument
import re
import hashlib


def generate_protobuf_safe_column_name(column_name):
    # Keep only alphanumeric characters; everything else becomes an underscore.
    masked = re.sub(r"[^a-zA-Z0-9]", "_", column_name)
    # Append a short SHA1 hash of the original string so two different column names
    # that mask to the same string still produce unique field names.
    hash_obj = hashlib.sha1(column_name.encode())
    hash_str = hash_obj.hexdigest()[:10]
    # Combine the masked string with the hash.
    return f"column__{masked}_{hash_str}"
# Import bigquery storage annotations so we can set the column_name to a flexible column name.
# See the code example in docs: https://cloud.google.com/bigquery/docs/schemas#flexible-column-names
#
# Load the bigquery storage annotations into the default descriptor pool.
import banzai.google.google_cloud_bigquery_storage_v1_annotations


def get_column_name_extension():
    return descriptor_pool.Default().FindExtensionByName(
        "google.cloud.bigquery.storage.v1.column_name"
    )
# HACK: I have to monkey patch AppendRowsStream until the following issue is closed
# and the response is accessible from an append rows request exception.
#
# GH: https://github.com/googleapis/python-bigquery-storage/issues/836
#
def _monkey_patch_writer_append_rows_stream(w: writer.AppendRowsStream):
    from google.api_core import exceptions
    from google.cloud.bigquery_storage_v1 import exceptions as bqstorage_exceptions

    # Taken from my patch here: https://github.com/googleapis/python-bigquery-storage/pull/838
    def _on_response(self, response: types.AppendRowsResponse):
        """Process a response from a consumer callback."""
        # If the stream has closed, but somehow we still got a response message
        # back, discard it. The response futures queue has been drained, with
        # an exception reported.
        if self._closed:
            raise bqstorage_exceptions.StreamClosedError(
                f"Stream closed before receiving response: {response}"
            )

        # Since we have 1 response per request, if we get here from a response
        # callback, the queue should never be empty.
        future: writer.AppendRowsFuture = self._futures_queue.get_nowait()
        if response.error.code:
            exc = exceptions.from_grpc_status(
                response.error.code, response.error.message, response=response
            )
            future.set_exception(exc)
        else:
            future.set_result(response)

    w._on_response = _on_response.__get__(w, type(w))
class BigQueryDirectStorageCopyException(Exception):
    pass


class BigQueryDirectStorageCopyBadResponseException(Exception):
    def __init__(
        self,
        exception,
        flexible_to_strict_column_names=None,
        proto_class="DynamicMessage",
    ):
        self.exception = exception
        self.response = exception.response
        m = f"{self.exception}\nResponse:\n{self.response}"
        # A bit of a hack to clean up the error message so it shows the flexible
        # column name and not the generated proto-friendly column name.
        for (
            flexible_col_name,
            strict_col_name,
        ) in (flexible_to_strict_column_names or {}).items():
            m = m.replace(f"{proto_class}.{strict_col_name}", flexible_col_name)
        self.error_message = m
        super().__init__(m)
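
# Illustration (added; the names below are made-up examples, not from the original
# gist): if BigQuery rejects a row with an error mentioning
# "DynamicMessage.column__signup_date__UTC__deadbeef01", the loop above rewrites
# the message to reference "signup-date (UTC)" instead, so the error points at the
# real column the caller knows about.
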
def to_gbq(
    df: pd.DataFrame, table: BigQueryTable, batch_size=501, max_concurrent_futures=200
):
    _append_df_to_bq_table(
        df,
        table=table,
        batch_size=batch_size,
        max_concurrent_futures=max_concurrent_futures,
    )


# NOTE: I used a stackoverflow article to write a bunch of this code:
# https://stackoverflow.com/questions/77428218/creating-a-protobuf-factory-for-a-dynamically-generated-message
def _create_dynamic_pb_for_schema(table_schema, selected_columns):
    message_descriptor_proto = descriptor_pb2.DescriptorProto()
    message_descriptor_proto.name = "DynamicMessage"

    field_num = 0
    column_name_to_proto_field_name = {}
    for column_name, schema_field in table_schema.items():
        # Skip any column that is not in both the table and the dataframe.
        if column_name not in selected_columns:
            continue
        field_num += 1

        proto_field_name = generate_protobuf_safe_column_name(column_name)
        column_name_to_proto_field_name[column_name] = proto_field_name

        # NOTE: The documentation for mapping BigQuery types onto protobuf types is found here:
        # https://cloud.google.com/bigquery/docs/write-api#data_type_conversions
        if schema_field.field_type in ("STRING", "DATE", "DATETIME", "JSON", "TIME"):
            field = message_descriptor_proto.field.add(
                name=proto_field_name,
                number=field_num,
                type=descriptor_pb2.FieldDescriptorProto.TYPE_STRING,
            )
        elif schema_field.field_type == "FLOAT":
            field = message_descriptor_proto.field.add(
                name=proto_field_name,
                number=field_num,
                type=descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE,
            )
        elif schema_field.field_type in ("INTEGER", "TIMESTAMP"):
            field = message_descriptor_proto.field.add(
                name=proto_field_name,
                number=field_num,
                type=descriptor_pb2.FieldDescriptorProto.TYPE_INT64,
            )
        elif schema_field.field_type == "BOOLEAN":
            field = message_descriptor_proto.field.add(
                name=proto_field_name,
                number=field_num,
                type=descriptor_pb2.FieldDescriptorProto.TYPE_BOOL,
            )
        else:
            field_type = schema_field.field_type
            raise BigQueryDirectStorageCopyException(
                f"Found unmapped field_type: {field_type}"
            )

        ext_column_name = get_column_name_extension()
        # Set the BigQuery column_name annotation, which takes priority over the field name.
        field_options = descriptor_pb2.FieldOptions()
        field_options.Extensions[ext_column_name] = column_name
        field.options.CopyFrom(field_options)

    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
    file_descriptor_proto.name = "dynamic_message.proto"
    file_descriptor_proto.package = "dynamic_package"
    file_descriptor_proto.message_type.add().CopyFrom(message_descriptor_proto)

    pool = descriptor_pool.DescriptorPool()
    file_descriptor = pool.Add(file_descriptor_proto)

    message_class = message_factory.GetMessages([file_descriptor_proto])[
        "dynamic_package.DynamicMessage"
    ]
    return message_class, column_name_to_proto_field_name
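
# Illustration (added; the schema below is a made-up example, not from the original
# gist): for a table with a STRING column "name" and an INTEGER column "age", the
# descriptor built above corresponds roughly to this .proto definition:
#
#   message DynamicMessage {
#     optional string column__name_<hash> = 1
#         [(google.cloud.bigquery.storage.v1.column_name) = "name"];
#     optional int64 column__age_<hash> = 2
#         [(google.cloud.bigquery.storage.v1.column_name) = "age"];
#   }
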
def _type_cast_column(column_name, new_type, value):
    try:
        return new_type(value)
    except ValueError as e:
        raise ValueError(
            f'Error casting type {value.__class__.__name__} to {new_type.__name__} for column "{column_name}": {e}'
        )


def _create_rows_from_df(
    df, msg_class, table_schema, selected_columns, column_name_to_proto_field_name
):
    proto_rows = types.ProtoRows()
    for _, row in df.iterrows():
        msg = msg_class()
        for column_name, value in row.items():
            # Skip the non-selected columns.
            if column_name not in selected_columns:
                continue
            # Skip NaN/None values; missing fields are simply left unset on the proto.
            if pd.isna(value):
                continue
            # We need to cast date/datetime values to an ISO 8601 string for BQ.
            if (
                table_schema[column_name].field_type == "DATE"
                or table_schema[column_name].field_type == "DATETIME"
            ):
                if isinstance(value, datetime):
                    # Drop the timezone so that the datetime string will not have "+00:00" appended,
                    # since that is not a valid DATETIME literal in BigQuery.
                    value = value.replace(tzinfo=None)
                if hasattr(value, "isoformat"):
                    value = value.isoformat()
            if table_schema[column_name].field_type == "FLOAT":
                value = _type_cast_column(column_name, float, value)
            # Handle time -> string conversion for datetime and time types.
            if table_schema[column_name].field_type == "TIME":
                # If we received a datetime, drop the date and only keep the time part.
                if isinstance(value, datetime):
                    value = value.time()
                if hasattr(value, "isoformat"):
                    value = value.isoformat()
            if table_schema[column_name].field_type == "INTEGER":
                value = _type_cast_column(column_name, int, value)
            # Handle timestamp -> int64 conversion for datetime types.
            if table_schema[column_name].field_type == "TIMESTAMP":
                if isinstance(value, datetime):
                    # We want a unix timestamp in microseconds.
                    value = int(datetime.timestamp(value) * 1000000)
            # Dump a JSON string if we don't already have a string.
            if table_schema[column_name].field_type == "JSON":
                if not isinstance(value, str):
                    value = json.dumps(value)
            if table_schema[column_name].field_type == "STRING":
                value = _type_cast_column(column_name, str, value)

            # Set the attribute on the proto.
            try:
                setattr(msg, column_name_to_proto_field_name[column_name], value)
            except TypeError as e:
                # The bare TypeError is not helpful, so re-wrap and re-raise it.
                raise TypeError(f"Error setting {column_name} on proto for upload: {e}")
        proto_rows.serialized_rows.append(msg.SerializeToString())
    return proto_rows
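
# Illustration of the TIMESTAMP conversion above (added; the input is a made-up
# example, not from the original gist):
#   datetime(2024, 1, 1, tzinfo=timezone.utc) -> 1704067200000000
# i.e. the value written to the INT64 proto field is a unix epoch in microseconds.
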
def _await_futures(futures, column_name_to_proto_field_name):
    for future in futures:
        try:
            future.result()
        except InvalidArgument as e:
            raise BigQueryDirectStorageCopyBadResponseException(
                e, flexible_to_strict_column_names=column_name_to_proto_field_name
            )
def _append_df_to_bq_table(
    df: pd.DataFrame, table, batch_size=500, max_concurrent_futures=200
):
    table_schema = {s.name: s for s in table.schema_fields()}
    table_columns = set(table_schema.keys())

    # We only want to select columns that are in both the table and the dataframe.
    # NOTE: This probably already happens when enforce_schema is used, but it's a little extra safety.
    columns = table_columns.intersection(set(df.columns))

    # Convert the DF into a series of serialized rows using a dynamically generated proto.
    msg_class, column_name_to_proto_field_name = _create_dynamic_pb_for_schema(
        table_schema=table_schema, selected_columns=columns
    )

    # NOTE: The batch size cannot exceed 10MB.
    df_chunks = [df[i : i + batch_size] for i in range(0, len(df), batch_size)]

    # NOTE: Here are some useful links to review this code:
    #
    # https://github.com/googleapis/python-bigquery-storage
    # The bigquery storage client codebase is useful to understand what's happening behind each call.
    #
    # https://cloud.google.com/bigquery/docs/write-api-batch
    # Example implementation of a python batch write client.

    # Start streaming the data.
    write_client = bigquery_storage_v1.BigQueryWriteClient()
    parent = write_client.table_path(
        table.get_project_id(), table.get_dataset_id(), table.name
    )
    write_stream = types.WriteStream()

    # Set up the pending stream for batch import.
    write_stream.type_ = types.WriteStream.Type.PENDING
    write_stream = write_client.create_write_stream(
        parent=parent, write_stream=write_stream
    )
    stream_name = write_stream.name

    request_template = types.AppendRowsRequest()
    request_template.write_stream = stream_name

    # Wire up the descriptor and add the row data.
    proto_schema = types.ProtoSchema()
    proto_descriptor = descriptor_pb2.DescriptorProto()
    msg_class.DESCRIPTOR.CopyToProto(proto_descriptor)
    proto_schema.proto_descriptor = proto_descriptor
    proto_data = types.AppendRowsRequest.ProtoData()
    proto_data.writer_schema = proto_schema
    request_template.proto_rows = proto_data

    append_rows_stream = writer.AppendRowsStream(write_client, request_template)
    # HACK: See above for docs. This should be removed soon.
    _monkey_patch_writer_append_rows_stream(append_rows_stream)

    offset = 0
    futures = []
    # Create an append rows request for each dataframe chunk.
    for df_chunk in df_chunks:
        # Turn the dataframe chunk into a ProtoRows full of serialized protos using our
        # dynamically generated protobuf class.
        proto_rows = _create_rows_from_df(
            df_chunk,
            msg_class,
            table_schema,
            selected_columns=columns,
            column_name_to_proto_field_name=column_name_to_proto_field_name,
        )
        request = types.AppendRowsRequest()
        request.offset = offset
        proto_data = types.AppendRowsRequest.ProtoData()
        proto_data.rows = proto_rows
        request.proto_rows = proto_data
        futures.append(append_rows_stream.send(request))

        # Check for too many pending futures and await them if needed.
        if len(futures) >= max_concurrent_futures:
            _await_futures(futures, column_name_to_proto_field_name)
            futures = []

        # Advance the offset by the number of rows added, for the next request should there be one.
        offset += len(df_chunk.index)

    # Wait for all append requests to finish.
    _await_futures(futures, column_name_to_proto_field_name)

    # Close the stream.
    append_rows_stream.close()
    write_client.finalize_write_stream(name=write_stream.name)

    # Commit all requests.
    batch_commit_write_streams_request = types.BatchCommitWriteStreamsRequest()
    batch_commit_write_streams_request.parent = parent
    batch_commit_write_streams_request.write_streams = [write_stream.name]
    write_client.batch_commit_write_streams(batch_commit_write_streams_request)
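Since the write stream is created as PENDING, none of the appended rows become visible until batch_commit_write_streams succeeds, so a failure part-way through leaves the destination table unchanged. As a quick sanity check after the copy, something like the sketch below works; the fully qualified table name is an assumed placeholder, not part of the original gist:

# Hypothetical post-copy check using the standard BigQuery client.
from google.cloud import bigquery

client = bigquery.Client()
count_job = client.query("SELECT COUNT(*) AS n FROM `my-project.my_dataset.some_table`")
print(next(iter(count_job.result())).n)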
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/cloud/bigquery/storage/v1/annotations.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b"\n2google/cloud/bigquery/storage/v1/annotations.proto\x12 google.cloud.bigquery.storage.v1\x1a google/protobuf/descriptor.proto:9\n\x0b\x63olumn_name\x12\x1d.google.protobuf.FieldOptions\x18\xb5\xc3\xf7\xd8\x01 \x01(\t\x88\x01\x01\x42\xc0\x01\n$com.google.cloud.bigquery.storage.v1B\x10\x41nnotationsProtoP\x01Z>cloud.google.com/go/bigquery/storage/apiv1/storagepb;storagepb\xaa\x02 Google.Cloud.BigQuery.Storage.V1\xca\x02 Google\\Cloud\\BigQuery\\Storage\\V1b\x06proto3"
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
    DESCRIPTOR, "google.cloud.bigquery.storage.v1.annotations_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
    _globals["DESCRIPTOR"]._options = None
    _globals["DESCRIPTOR"]._serialized_options = (
        b"\n$com.google.cloud.bigquery.storage.v1B\020AnnotationsProtoP\001Z>cloud.google.com/go/bigquery/storage/apiv1/storagepb;storagepb\252\002 Google.Cloud.BigQuery.Storage.V1\312\002 Google\\Cloud\\BigQuery\\Storage\\V1"
    )
# @@protoc_insertion_point(module_scope)
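
# NOTE (added, not part of the generated output): this module is protoc output for
# annotations.proto from the googleapis repo (the file named in the "source:" comment
# above). To regenerate it, something like the following should work, assuming a
# googleapis checkout at /path/to/googleapis:
#
#   protoc -I/path/to/googleapis --python_out=. \
#       google/cloud/bigquery/storage/v1/annotations.proto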