Skip to content

Instantly share code, notes, and snippets.

@danyangz
Created December 28, 2019 04:10
Show Gist options
  • Save danyangz/57cb4c3e23a852c751d8c2d1ff4e350b to your computer and use it in GitHub Desktop.
Save danyangz/57cb4c3e23a852c751d8c2d1ff4e350b to your computer and use it in GitHub Desktop.
Asynchronous gRPC full-duplex streaming example
#include <unistd.h>

#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <grpc/grpc.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/resource_quota.h>
#include <grpcpp/support/channel_arguments.h>

#include "benchmark.grpc.pb.h"
using grpc::Channel;
using grpc::ClientAsyncReaderWriter;
using grpc::ClientContext;
using grpc::CompletionQueue;
using grpc::Status;
using benchmark::Ack;
using benchmark::Benchmark;
using benchmark::Data;
class BenchmarkClient {
public:
explicit BenchmarkClient(std::string bind_address, int num_threads) {
for (int i = 0; i < num_threads; i++) {
grpc::ResourceQuota quota;
quota.SetMaxThreads(4);
grpc::ChannelArguments argument;
argument.SetResourceQuota(quota);
auto channel = grpc::CreateCustomChannel(
bind_address, grpc::InsecureChannelCredentials(), argument);
stubs_.push_back(Benchmark::NewStub(channel));
}
num_threads_ = num_threads;
}
void Run(const std::string &str, int payload_size) {
payload_size_ = payload_size;
int max_outgoing = 100;
cqs_.reserve(num_threads_);
polling_threads_.reserve(num_threads_);
for (int i = 0; i < num_threads_; i++) {
cqs_.emplace_back();
}
for (int i = 0; i < num_threads_; i++) {
polling_threads_.emplace_back(&BenchmarkClient::PollCompletionQueue, this,
i, str);
}
for (int i = 0; i < max_outgoing; i++) {
AsyncClientCall *call = new AsyncClientCall;
int id = rand() % num_threads_;
call->stream = stubs_[id]->PrepareAsyncSendDataStreamFullDuplex(
&call->context, &cqs_[id]);
call->count = 0;
call->stream->StartCall((void *)call);
call->sendfinished = false;
call->finished = false;
call->count = 0;
call->writing = true;
}
for (int i = 0; i < num_threads_; i++) {
polling_threads_[i].join();
}
}
void PollCompletionQueue(int id, const std::string &str) {
void *tag;
bool ok = false;
long onehm = 1024 * 1024 * 100;
long oneg = 1024 * 1024 * 1024;
long packets_to_report = oneg / payload_size_;
long packets_to_report_100m = onehm / payload_size_;
long c = 0;
auto start = std::chrono::system_clock::now();
while (cqs_[id].Next(&tag, &ok)) {
AsyncClientCall *call = static_cast<AsyncClientCall *>(tag);
if (not ok) {
std::cout << "call = " << call << std::endl;
std::cout << "count " << call->count << std::endl;
std::cout << "sendfinished " << call->sendfinished << std::endl;
std::cout << "finished " << call->finished << std::endl;
}
GPR_ASSERT(ok);
if ((call->count < 10) && (call->writing == true)) {
Data d;
d.set_data(str);
call->writing = false;
call->stream->Write(d, (void *)call);
call->count ++;
continue;
}
if ((call->count <= 10) && (call->writing == false)) {
call->writing = true;
call->stream->Read(&call->ack, (void*) call);
continue;
}
if (not call->sendfinished) {
call->stream->WritesDone((void *)call);
call->sendfinished = true;
continue;
}
if (not call->finished) {
call->stream->Finish(&call->status, (void *)call);
call->finished = true;
continue;
}
if (call->status.ok()) {
c += 10;
if (c > packets_to_report_100m) {
auto end = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end - start;
double gbps = 8.0 * onehm / diff.count() / 1e9;
std::cout << id << ": " << gbps << " Gbps" << std::endl;
start = end;
c = 0;
}
} else {
std::cout << "RPC failed" << std::endl;
}
delete call;
call = new AsyncClientCall;
call->stream = stubs_[id]->PrepareAsyncSendDataStreamFullDuplex(
&call->context, &cqs_[id]);
call->stream->StartCall((void *)call);
call->sendfinished = false;
call->finished = false;
call->count = 0;
call->writing = true;
}
}
private:
struct AsyncClientCall {
Ack ack;
ClientContext context;
Status status;
std::unique_ptr<ClientAsyncReaderWriter<Data, Ack>> stream;
int count;
bool sendfinished;
bool finished;
bool writing;
};
int num_threads_;
int payload_size_;
std::vector<std::unique_ptr<Benchmark::Stub>> stubs_;
std::vector<CompletionQueue> cqs_;
std::vector<std::thread> polling_threads_;
};
/// Client entry point:
///   <binary> <num_threads> <server_ip> <port> <payload_size>
/// Streams payloads of `payload_size` 'a' bytes to server_ip:port using
/// `num_threads` channels/polling threads.
int main(int argc, char **argv) {
  if (argc < 5) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "client")
              << " <num_threads> <server_ip> <port> <payload_size>"
              << std::endl;
    return 1;
  }

  int num_threads = 0;
  int port = 0;
  int payload_size = 0;
  try {
    num_threads = std::stoi(argv[1]);
    port = std::stoi(argv[3]);
    payload_size = std::stoi(argv[4]);
  } catch (const std::exception &e) {
    // std::stoi throws on non-numeric or out-of-range input.
    std::cerr << "invalid numeric argument: " << e.what() << std::endl;
    return 1;
  }
  if (num_threads <= 0 || payload_size <= 0) {
    std::cerr << "num_threads and payload_size must be positive" << std::endl;
    return 1;
  }

  std::string ip = std::string(argv[2]);
  // Every Data message carries this payload.
  std::string payload(static_cast<size_t>(payload_size), 'a');
  auto bind_address = ip + ":" + std::to_string(port);
  BenchmarkClient client(bind_address, num_threads);
  client.Run(payload, payload_size);
  return 0;
}
// Schema for the gRPC streaming throughput benchmark.
syntax = "proto3";
package benchmark;
// Bidirectional streaming RPC: the client sends a stream of Data messages and
// the server replies with a stream of Ack messages.
service Benchmark {
rpc SendDataStreamFullDuplex(stream Data) returns (stream Ack) {}
}
// One benchmark payload sent by the client.
message Data {
bytes data = 1;
}
// Empty acknowledgement sent by the server; it carries no fields.
message Ack {
}
#include <fstream>
#include <grpc/grpc.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <grpcpp/server_context.h>
#include <iostream>
#include <string>
#include <thread>
#include <unistd.h>
#include "benchmark.grpc.pb.h"
using grpc::Server;
using grpc::ServerAsyncReaderWriter;
using grpc::ServerBuilder;
using grpc::ServerCompletionQueue;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::Status;
using benchmark::Ack;
using benchmark::Benchmark;
using benchmark::Data;
/// Asynchronous gRPC server for Benchmark.SendDataStreamFullDuplex.
///
/// Run() creates one ServerCompletionQueue per worker thread. Each worker keeps
/// a HandleRPC waiting for a new call on its queue and drives every accepted
/// stream through a Read -> Write(Ack) loop until a Read or Write completes
/// with ok == false (e.g. after the client half-closes), then finishes the
/// call with Status::OK.
class BenchmarkAsyncImpl final {
public:
  /// Shuts the server down, then shuts down and drains every completion
  /// queue; gRPC requires a completion queue to be drained before it is
  /// destroyed.
  ~BenchmarkAsyncImpl() {
    // server_ is null if Run() was never called or BuildAndStart() failed.
    if (server_) {
      server_->Shutdown();
    }
    for (size_t i = 0; i < cqs_.size(); i++) {
      cqs_[i]->Shutdown();
      void *ignored_tag;
      bool ignored_ok;
      while (cqs_[i]->Next(&ignored_tag, &ignored_ok)) {
        // Drain only; outstanding HandleRPC objects are not freed here.
      }
    }
  }

  /// Builds and starts the server on `server_address:port`, then serves on
  /// `num_threads` threads (one completion queue each) and joins them. A
  /// worker only exits after its queue has been shut down and drained, so
  /// this call normally never returns.
  void Run(std::string server_address, int port, size_t num_threads) {
    ServerBuilder builder;
    std::string bind_address = server_address + ":" + std::to_string(port);
    builder.AddListeningPort(bind_address, grpc::InsecureServerCredentials());
    builder.RegisterService(&service_);
    for (size_t i = 0; i < num_threads; i++) {
      cqs_.push_back(builder.AddCompletionQueue());
    }
    server_ = builder.BuildAndStart();
    if (!server_) {
      // BuildAndStart() returns null when, e.g., the port cannot be bound.
      std::cerr << "Failed to start server on " << bind_address << std::endl;
      return;
    }
    std::cout << "Async server listening on " << bind_address << std::endl;
    std::vector<std::thread> threads;
    threads.reserve(num_threads);
    for (size_t i = 0; i < num_threads; i++) {
      threads.emplace_back(HandleRPCs, &service_, cqs_[i].get());
    }
    for (size_t i = 0; i < num_threads; i++) {
      threads[i].join();
    }
  }

private:
  /// State machine for one SendDataStreamFullDuplex call. The object's
  /// address is the tag for every operation on the call, at most one
  /// operation is outstanding at a time, and the object deletes itself once
  /// Finish completes.
  ///
  /// NOTE(review): shutting down while streams are active is not handled
  /// cleanly — P() may still issue Read/Write/Finish after the queue has been
  /// shut down.
  class HandleRPC {
  public:
    HandleRPC(Benchmark::AsyncService *service, ServerCompletionQueue *cq)
        : service_(service), cq_(cq), stream_(&ctx_), status_(CREATE) {
      P(nullptr, true);
    }

    /// Advances the state machine after the previous operation completed.
    /// @param d   Scratch string that receives a copy of each payload read
    ///            (unused in the CREATE state).
    /// @param ok  Completion status reported by the completion queue.
    void P(std::string *d, bool ok) {
      if (status_ == CREATE) {
        status_ = PROCESS;
        service_->RequestSendDataStreamFullDuplex(&ctx_, &stream_, cq_, cq_,
                                                  this);
      } else if (status_ == PROCESS) {
        if (!ok) {
          // gRPC reports ok == false for a pending request when the server is
          // shutting down: no call was accepted, so there is no stream to
          // drive and no replacement request may be queued.
          delete this;
          return;
        }
        // Arm a fresh handler so the next incoming call can be accepted.
        new HandleRPC(service_, cq_);
        status_ = PROCESSING_READ;
        stream_.Read(&data_, this);
      } else if (status_ == PROCESSING_READ) {
        // Copy the payload out of the message, then reset it for the next Read.
        d->assign(data_.data());
        data_.Clear();
        if (ok) {
          status_ = PROCESSING_WRITE;
          stream_.Write(ack_, this);
        } else {
          // No more messages from the client: complete the call.
          status_ = FINISH;
          stream_.Finish(Status::OK, this);
        }
      } else if (status_ == PROCESSING_WRITE) {
        if (ok) {
          status_ = PROCESSING_READ;
          stream_.Read(&data_, this);
        } else {
          status_ = FINISH;
          stream_.Finish(Status::OK, this);
        }
      } else {
        GPR_ASSERT(status_ == FINISH);
        delete this;
      }
    }

  private:
    Benchmark::AsyncService *service_;
    ServerCompletionQueue *cq_;
    ServerContext ctx_;
    Data data_;
    Ack ack_;
    ServerAsyncReaderWriter<Ack, Data> stream_;
    enum CallStatus { CREATE, PROCESS, PROCESSING_READ, PROCESSING_WRITE, FINISH };
    CallStatus status_;
  };

  /// Worker loop for one completion queue: arms the first handler, then
  /// dispatches every completed tag to its HandleRPC. Returns once the queue
  /// has been shut down and fully drained.
  static void HandleRPCs(Benchmark::AsyncService *service,
                         ServerCompletionQueue *cq) {
    new HandleRPC(service, cq);
    void *tag;
    bool ok;
    // Per-thread scratch buffer handed to HandleRPC::P for payload copies.
    std::string a;
    while (cq->Next(&tag, &ok)) {
      static_cast<HandleRPC *>(tag)->P(&a, ok);
    }
  }

  std::vector<std::unique_ptr<ServerCompletionQueue>> cqs_;
  Benchmark::AsyncService service_;
  // Declared last so it is destroyed first, before the service and queues.
  std::unique_ptr<Server> server_;
};
void RunServer(std::string ip, int port, size_t num_threads) {
BenchmarkAsyncImpl server;
server.Run(ip, port, num_threads);
}
/// Server entry point: <binary> <num_threads> <port>
/// Listens on all interfaces (0.0.0.0) with one completion queue per thread.
int main(int argc, char **argv) {
  if (argc < 3) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "server")
              << " <num_threads> <port>" << std::endl;
    return 1;
  }
  std::string ip = "0.0.0.0";
  int num_threads = atoi(argv[1]);
  int port = atoi(argv[2]);
  // atoi yields 0 on garbage; a negative thread count would otherwise wrap
  // to a huge size_t.
  if (num_threads <= 0 || port <= 0 || port > 65535) {
    std::cerr << "num_threads must be positive and port must be in [1, 65535]"
              << std::endl;
    return 1;
  }
  RunServer(ip, port, static_cast<size_t>(num_threads));
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment