Last active
September 20, 2019 07:42
-
-
Save sllynn/44063873c6201bf8aae9b6486b97fa87 to your computer and use it in GitHub Desktop.
kinesis writer (includes some other logic relevant to multiclass classification of documents)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
from math import ceil

import boto3
import numpy as np
import pandas as pd
class KinesisWriter:
    """Forward classified documents from a batch DataFrame to a Kinesis stream.

    Only rows whose ``predicted_class`` is contained in ``classes`` are sent.
    Each forwarded row becomes one Kinesis record whose payload is a JSON
    object with the keys ``msgid``, ``msg``, ``class_pred`` and
    ``marginal_prob``; the row's ``id`` is used as the partition key.

    NOTE(review): the ``process(batch_df, batch_id)`` signature looks intended
    for Spark Structured Streaming's ``foreachBatch`` -- confirm with the
    caller.
    """

    # Largest number of records sent in a single ``put_records`` call
    # (the documented per-request limit of the Kinesis PutRecords API).
    MAX_RECORDS_PER_CALL = 500

    def __init__(self, region, stream, classes):
        """Store the connection settings; no client is created here.

        Args:
            region: AWS region name handed to ``boto3.client``.
            stream: Name of the Kinesis stream to write to.
            classes: Collection of predicted-class values to forward.
        """
        self.kinesis_client = None
        self.kinesis_region = region
        self.kinesis_stream = stream
        self.classes = classes

    def open(self, partition_id=None, epoch_id=None):
        """Create the boto3 Kinesis client for the configured region.

        ``partition_id`` and ``epoch_id`` are accepted but not used.
        """
        self.kinesis_client = boto3.client("kinesis", region_name=self.kinesis_region)

    def process(self, batch_df, batch_id):
        """Send the rows of ``batch_df`` with a wanted class to Kinesis.

        Args:
            batch_df: Object exposing ``toPandas()`` that yields a pandas
                DataFrame with the columns ``id``, ``text``,
                ``predicted_class`` and ``marginal_confidence``.
            batch_id: Accepted but not used.

        Returns:
            ``True`` when at least one record was sent and Kinesis accepted
            every record; ``False`` when there was nothing to send, when a
            row could not be serialised, or when Kinesis rejected any record.
        """
        log = logging.getLogger(__name__)

        # Convert once and test emptiness locally instead of running a
        # separate ``count()`` over the batch first.
        local_df = batch_df.toPandas()
        if local_df.empty:
            return False

        filtered_df = local_df[local_df.predicted_class.isin(self.classes)]
        try:
            records = self._build_records(filtered_df)
            if not records:
                return False

            # Tolerate callers that never invoked ``open()`` explicitly.
            if self.kinesis_client is None:
                self.open()

            rejected = 0
            page_size = self.MAX_RECORDS_PER_CALL
            for start in range(0, len(records), page_size):
                response = self.kinesis_client.put_records(
                    Records=records[start:start + page_size],
                    StreamName=self.kinesis_stream,
                )
                # PutRecords is not atomic: individual records can be
                # rejected even though the call itself succeeds.
                rejected += (response or {}).get("FailedRecordCount", 0)
        except (TypeError, ValueError, KeyError):
            # Best-effort writer: report the failure instead of raising,
            # but leave a trace of what went wrong.
            log.exception("Could not write batch to Kinesis stream %s", self.kinesis_stream)
            return False

        if rejected:
            log.warning(
                "Kinesis stream %s rejected %d of %d records",
                self.kinesis_stream, rejected, len(records),
            )
            return False
        return True

    @staticmethod
    def _build_records(filtered_df):
        """Turn each row of ``filtered_df`` into a ``put_records`` entry."""
        columns = [
            filtered_df[name].tolist()
            for name in ("id", "text", "predicted_class", "marginal_confidence")
        ]
        return [
            dict(
                Data=json.dumps(
                    dict(
                        msgid=msg_id,
                        msg=text,
                        class_pred="{0}".format(predicted),
                        marginal_prob="{0:.3f}".format(confidence),
                    )
                ),
                # Kinesis requires the partition key to be a string.
                PartitionKey=str(msg_id),
            )
            for msg_id, text, predicted, confidence in zip(*columns)
        ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment