@film42 · Created October 24, 2024 22:08
Write a dataframe to BigQuery using the Storage Write API, with flexible column names handled via dynamically generated protos
# NOTE: The BigQueryTable class is something we wrote to wrap a bunch of operations and migrate tables,
# but you can get the gist of what we're doing from this (a minimal stand-in is sketched after the
# example below). Converting types is not the hard part.
#
# NOTE: The protobuf descriptor comes from: https://github.com/googleapis/googleapis/blob/master/google/cloud/bigquery/storage/v1/annotations.proto
import big_query_storage_write_api as s
s.to_gbq(
df=df,
table=internal.BigQueryTable("some_table"),
)
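# NOTE (illustrative sketch, not the internal class): the module below only relies on a
# small surface of BigQueryTable -- schema_fields(), get_project_id(), get_dataset_id(),
# and .name -- so a minimal stand-in built directly on google.cloud.bigquery could look
# something like this:
from google.cloud import bigquery

class MinimalBigQueryTable:
    def __init__(self, project_id, dataset_id, name):
        self.name = name
        self._project_id = project_id
        self._dataset_id = dataset_id
        self._client = bigquery.Client(project=project_id)

    def schema_fields(self):
        # The list of bigquery.SchemaField objects describing the target table.
        return self._client.get_table(
            f"{self._project_id}.{self._dataset_id}.{self.name}"
        ).schema

    def get_project_id(self):
        return self._project_id

    def get_dataset_id(self):
        return self._dataset_id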
from google.cloud import bigquery
import pandas as pd
import banzai.google as google
from datetime import datetime
import json
import numpy as np
from banzai.google.big_query_table import BigQueryTable
from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types, writer
from google.api_core.exceptions import InvalidArgument
import re
import hashlib
def generate_protobuf_safe_column_name(column_name):
# Only keep alphanumeric characters; everything else becomes an underscore.
masked = re.sub(r"[^a-zA-Z0-9]", "_", column_name)
# Append a short SHA1 hash of the original string so that different column names which
# collapse to the same masked string don't collide.
hash_obj = hashlib.sha1(column_name.encode())
hash_str = hash_obj.hexdigest()[:10]
# Combine masked string with hash
return f"column__{masked}_{hash_str}"
# Import the BigQuery Storage annotations so we can set the column_name annotation to a flexible column name.
# See the code example in docs: https://cloud.google.com/bigquery/docs/schemas#flexible-column-names
#
# Load the bigquery storage annotations into the default descriptor pool.
import banzai.google.google_cloud_bigquery_storage_v1_annotations
def get_column_name_extension():
return descriptor_pool.Default().FindExtensionByName(
"google.cloud.bigquery.storage.v1.column_name"
)
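# Illustration (hedged sketch): attaching the column_name annotation to a field's options,
# which is exactly what _create_dynamic_pb_for_schema does for every column further below.
def _example_column_name_annotation(flexible_name="order id (2024)"):
    options = descriptor_pb2.FieldOptions()
    options.Extensions[get_column_name_extension()] = flexible_name
    return options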
# HACK: I have to monkey patch AppendRowsStream until the following issue is closed
# and the response is accessible from an append rows request exception.
#
# GH: https://github.com/googleapis/python-bigquery-storage/issues/836
#
def _monkey_patch_writer_append_rows_stream(w: writer.AppendRowsStream):
from google.api_core import exceptions
from google.cloud.bigquery_storage_v1 import exceptions as bqstorage_exceptions
# Taken from my patch here: https://github.com/googleapis/python-bigquery-storage/pull/838
def _on_response(self, response: types.AppendRowsResponse):
"""Process a response from a consumer callback."""
# If the stream has closed, but somehow we still got a response message
# back, discard it. The response futures queue has been drained, with
# an exception reported.
if self._closed:
raise bqstorage_exceptions.StreamClosedError(
f"Stream closed before receiving response: {response}"
)
# Since we have 1 response per request, if we get here from a response
# callback, the queue should never be empty.
future: writer.AppendRowsFuture = self._futures_queue.get_nowait()
if response.error.code:
exc = exceptions.from_grpc_status(
response.error.code, response.error.message, response=response
)
future.set_exception(exc)
else:
future.set_result(response)
w._on_response = _on_response.__get__(w, type(w))
class BigQueryDirectStorageCopyException(Exception):
pass
class BigQueryDirectStorageCopyBadResponseException(Exception):
def __init__(
self,
exception,
flexible_to_strict_column_names={},
proto_class="DynamicMessage",
):
self.exception = exception
self.response = exception.response
m = f"{self.exception}\nResponse:\n{self.response}"
# A bit of a hack to clean up the error messages so they reference the flexible column name
# and not the generated proto-friendly column name.
for (
flexible_col_name,
strict_col_name,
) in flexible_to_strict_column_names.items():
m = m.replace(f"{proto_class}.{strict_col_name}", flexible_col_name)
self.error_message = m
super().__init__(m)
def to_gbq(
df: pd.DataFrame, table: BigQueryTable, batch_size=501, max_concurrent_futures=200
):
_append_df_to_bq_table(
df,
table=table,
batch_size=batch_size,
max_concurrent_futures=max_concurrent_futures,
)
# NOTE: A bunch of this code was written with help from this Stack Overflow question:
# https://stackoverflow.com/questions/77428218/creating-a-protobuf-factory-for-a-dynamically-generated-message
def _create_dynamic_pb_for_schema(table_schema, selected_columns):
message_descriptor_proto = descriptor_pb2.DescriptorProto()
message_descriptor_proto.name = "DynamicMessage"
field_num = 0
column_name_to_proto_field_name = {}
for column_name, schema_field in table_schema.items():
# Skip any column that isn't present in both the table and the dataframe.
if column_name not in selected_columns:
continue
field_num += 1
proto_field_name = generate_protobuf_safe_column_name(column_name)
column_name_to_proto_field_name[column_name] = proto_field_name
# NOTE: The documentation for mapping BigQuery types to protobuf types is found here:
# https://cloud.google.com/bigquery/docs/write-api#data_type_conversions
if (
schema_field.field_type == "STRING"
or schema_field.field_type == "DATE"
or schema_field.field_type == "DATETIME"
or schema_field.field_type == "JSON"
or schema_field.field_type == "TIME"
):
field = message_descriptor_proto.field.add(
name=proto_field_name,
number=field_num,
type=descriptor_pb2.FieldDescriptorProto.TYPE_STRING,
)
elif schema_field.field_type == "FLOAT":
field = message_descriptor_proto.field.add(
name=proto_field_name,
number=field_num,
type=descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE,
)
elif (
schema_field.field_type == "INTEGER"
or schema_field.field_type == "TIMESTAMP"
):
field = message_descriptor_proto.field.add(
name=proto_field_name,
number=field_num,
type=descriptor_pb2.FieldDescriptorProto.TYPE_INT64,
)
elif schema_field.field_type == "BOOLEAN":
field = message_descriptor_proto.field.add(
name=proto_field_name,
number=field_num,
type=descriptor_pb2.FieldDescriptorProto.TYPE_BOOL,
)
else:
field_type = schema_field.field_type
raise BigQueryDirectStorageCopyException(
f"Found unmapped field_type: {field_type}"
)
ext_column_name = get_column_name_extension()
# Set the BigQuery column_name annotation, which takes priority over the proto field name.
field_options = descriptor_pb2.FieldOptions()
field_options.Extensions[ext_column_name] = column_name
field.options.CopyFrom(field_options)
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
file_descriptor_proto.name = "dynamic_message.proto"
file_descriptor_proto.package = "dynamic_package"
file_descriptor_proto.message_type.add().CopyFrom(message_descriptor_proto)
message_class = message_factory.GetMessages([file_descriptor_proto])[
"dynamic_package.DynamicMessage"
]
return message_class, column_name_to_proto_field_name
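# Illustration (hedged sketch): building the dynamic message class for a tiny two-column
# schema. The SchemaField objects here stand in for whatever table.schema_fields() returns.
def _example_dynamic_message():
    schema = {
        "order id": bigquery.SchemaField("order id", "STRING"),
        "total": bigquery.SchemaField("total", "FLOAT"),
    }
    msg_class, name_map = _create_dynamic_pb_for_schema(
        table_schema=schema, selected_columns={"order id", "total"}
    )
    msg = msg_class()
    setattr(msg, name_map["order id"], "abc-123")
    setattr(msg, name_map["total"], 19.99)
    return msg.SerializeToString()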
def _type_cast_column(column_name, new_type, value):
try:
return new_type(value)
except ValueError as e:
raise ValueError(
f'Error casting type {value.__class__.__name__} to {new_type.__name__} for column "{column_name}": {e}'
)
def _create_rows_from_df(
df, msg_class, table_schema, selected_columns, column_name_to_proto_field_name
):
proto_rows = types.ProtoRows()
for _, row in df.iterrows():
msg = msg_class()
for column_name, value in row.items():
# Skip the non-selected columns
if column_name not in selected_columns:
continue
# Skip NaN values since unset fields are implicit in protobuf (they land as NULL in BigQuery).
if pd.isna(value):
continue
# We need to cast date/datetime values to an ISO 8601 string for BQ.
if (
table_schema[column_name].field_type == "DATE"
or table_schema[column_name].field_type == "DATETIME"
):
if isinstance(value, datetime):
# Drop the timezone so that any datetime string will not have "+00:00" appended
# since it's not a valid datetime literal in BigQuery.
value = value.replace(tzinfo=None)
if hasattr(value, "isoformat"):
value = value.isoformat()
if table_schema[column_name].field_type == "FLOAT":
value = _type_cast_column(column_name, float, value)
# Handle time -> string conversion for datetime and time types.
if table_schema[column_name].field_type == "TIME":
# If we received a datetime, we should drop the date and only keep the time part.
if isinstance(value, datetime):
value = value.time()
if hasattr(value, "isoformat"):
value = value.isoformat()
if table_schema[column_name].field_type == "INTEGER":
value = _type_cast_column(column_name, int, value)
# Handle timestamp -> int64 conversion for datetime types.
if table_schema[column_name].field_type == "TIMESTAMP":
if isinstance(value, datetime):
# We want a unix timestamp in microseconds.
value = int(datetime.timestamp(value) * 1000000)
# Dump to a JSON string if the value isn't already a string.
if table_schema[column_name].field_type == "JSON":
if not isinstance(value, str):
value = json.dumps(value)
if table_schema[column_name].field_type == "STRING":
value = _type_cast_column(column_name, str, value)
# Set the attribute on the proto
try:
setattr(msg, column_name_to_proto_field_name[column_name], value)
except TypeError as e:
# The lone type error is not helpful, so let's re-wrap and re-raise.
raise TypeError(f"Error setting {column_name} on proto for upload: {e}")
proto_rows.serialized_rows.append(msg.SerializeToString())
return proto_rows
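# Illustration (hedged sketch): converting a small dataframe into ProtoRows with the
# dynamic message class; the NaN cell in the second row is simply left unset (NULL).
def _example_rows_from_df():
    schema = {
        "order id": bigquery.SchemaField("order id", "STRING"),
        "total": bigquery.SchemaField("total", "FLOAT"),
    }
    columns = set(schema.keys())
    msg_class, name_map = _create_dynamic_pb_for_schema(
        table_schema=schema, selected_columns=columns
    )
    df = pd.DataFrame(
        [
            {"order id": "abc-123", "total": 19.99},
            {"order id": "def-456", "total": None},
        ]
    )
    return _create_rows_from_df(
        df,
        msg_class,
        schema,
        selected_columns=columns,
        column_name_to_proto_field_name=name_map,
    )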
def _await_futures(futures, column_name_to_proto_field_name):
for future in futures:
try:
future.result()
except InvalidArgument as e:
raise BigQueryDirectStorageCopyBadResponseException(
e, flexible_to_strict_column_names=column_name_to_proto_field_name
)
def _append_df_to_bq_table(
df: pd.DataFrame, table, batch_size=500, max_concurrent_futures=200
):
table_schema = {s.name: s for s in table.schema_fields()}
table_columns = set(table_schema.keys())
# We only want to select columns that are both in the table and in the dataframe.
# NOTE: This probably already happens when enforce_schema is used but just being a little safe.
columns = table_columns.intersection(set(df.columns))
# Convert the DF into a series of serialized rows using a dynamic proto
msg_class, column_name_to_proto_field_name = _create_dynamic_pb_for_schema(
table_schema=table_schema, selected_columns=columns
)
# NOTE: The batch size cannot exceed 10MB.
df_chunks = [df[i : i + batch_size] for i in range(0, len(df), batch_size)]
# NOTE: Here are some useful links to review this code:
#
# https://github.com/googleapis/python-bigquery-storage
# The bigquery storage client codebase is useful to understand what's happening behind each call.
#
# https://cloud.google.com/bigquery/docs/write-api-batch
# An example implementation of a Python batch write client.
# Start streaming the data
write_client = bigquery_storage_v1.BigQueryWriteClient()
parent = write_client.table_path(
table.get_project_id(), table.get_dataset_id(), table.name
)
write_stream = types.WriteStream()
# Set up the pending stream for batch import
write_stream.type_ = types.WriteStream.Type.PENDING
write_stream = write_client.create_write_stream(
parent=parent, write_stream=write_stream
)
stream_name = write_stream.name
request_template = types.AppendRowsRequest()
request_template.write_stream = stream_name
# Wire up the writer schema (the proto descriptor); the row data itself is added per request below.
proto_schema = types.ProtoSchema()
proto_descriptor = descriptor_pb2.DescriptorProto()
msg_class.DESCRIPTOR.CopyToProto(proto_descriptor)
proto_schema.proto_descriptor = proto_descriptor
proto_data = types.AppendRowsRequest.ProtoData()
proto_data.writer_schema = proto_schema
request_template.proto_rows = proto_data
append_rows_stream = writer.AppendRowsStream(write_client, request_template)
# HACK: See above for docs. This should be removed soon.
_monkey_patch_writer_append_rows_stream(append_rows_stream)
offset = 0
futures = []
# Create an append rows request for each dataframe chunk.
for df in df_chunks:
# Turn the dataframe into a ProtoRows full of serialized protos using our dynamically generated protobuf class.
proto_rows = _create_rows_from_df(
df,
msg_class,
table_schema,
selected_columns=columns,
column_name_to_proto_field_name=column_name_to_proto_field_name,
)
request = types.AppendRowsRequest()
request.offset = offset
proto_data = types.AppendRowsRequest.ProtoData()
proto_data.rows = proto_rows
request.proto_rows = proto_data
futures.append(append_rows_stream.send(request))
# Check for too many pending futures and await them if needed.
if len(futures) >= max_concurrent_futures:
_await_futures(futures, column_name_to_proto_field_name)
futures = []
# Advance the offset by the number of rows just added, in case there is another request.
offset += len(df.index)
# Wait for all append requests to finish.
_await_futures(futures, column_name_to_proto_field_name)
# Close the stream
append_rows_stream.close()
write_client.finalize_write_stream(name=write_stream.name)
# Commit all requests
batch_commit_write_streams_request = types.BatchCommitWriteStreamsRequest()
batch_commit_write_streams_request.parent = parent
batch_commit_write_streams_request.write_streams = [write_stream.name]
write_client.batch_commit_write_streams(batch_commit_write_streams_request)
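# NOTE: The module above imports banzai.google.google_cloud_bigquery_storage_v1_annotations;
# the generated file below is that module -- protoc output for
# google/cloud/bigquery/storage/v1/annotations.proto, vendored so the column_name extension
# is registered in the default descriptor pool.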
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/cloud/bigquery/storage/v1/annotations.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n2google/cloud/bigquery/storage/v1/annotations.proto\x12 google.cloud.bigquery.storage.v1\x1a google/protobuf/descriptor.proto:9\n\x0b\x63olumn_name\x12\x1d.google.protobuf.FieldOptions\x18\xb5\xc3\xf7\xd8\x01 \x01(\t\x88\x01\x01\x42\xc0\x01\n$com.google.cloud.bigquery.storage.v1B\x10\x41nnotationsProtoP\x01Z>cloud.google.com/go/bigquery/storage/apiv1/storagepb;storagepb\xaa\x02 Google.Cloud.BigQuery.Storage.V1\xca\x02 Google\\Cloud\\BigQuery\\Storage\\V1b\x06proto3"
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "google.cloud.bigquery.storage.v1.annotations_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals["DESCRIPTOR"]._options = None
_globals["DESCRIPTOR"]._serialized_options = (
b"\n$com.google.cloud.bigquery.storage.v1B\020AnnotationsProtoP\001Z>cloud.google.com/go/bigquery/storage/apiv1/storagepb;storagepb\252\002 Google.Cloud.BigQuery.Storage.V1\312\002 Google\\Cloud\\BigQuery\\Storage\\V1"
)
# @@protoc_insertion_point(module_scope)