Created December 28, 2019 04:10
-
-
Save danyangz/57cb4c3e23a852c751d8c2d1ff4e350b to your computer and use it in GitHub Desktop.
Asynchronous gRPC full-duplex streaming example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <unistd.h>

#include <chrono>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include <grpc/grpc.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/resource_quota.h>
#include <grpcpp/support/channel_arguments.h>

#include "benchmark.grpc.pb.h"
using grpc::Channel; | |
using grpc::ClientAsyncReaderWriter; | |
using grpc::ClientContext; | |
using grpc::CompletionQueue; | |
using grpc::Status; | |
using benchmark::Ack; | |
using benchmark::Benchmark; | |
using benchmark::Data; | |
class BenchmarkClient { | |
public: | |
explicit BenchmarkClient(std::string bind_address, int num_threads) { | |
for (int i = 0; i < num_threads; i++) { | |
grpc::ResourceQuota quota; | |
quota.SetMaxThreads(4); | |
grpc::ChannelArguments argument; | |
argument.SetResourceQuota(quota); | |
auto channel = grpc::CreateCustomChannel( | |
bind_address, grpc::InsecureChannelCredentials(), argument); | |
stubs_.push_back(Benchmark::NewStub(channel)); | |
} | |
num_threads_ = num_threads; | |
} | |
void Run(const std::string &str, int payload_size) { | |
payload_size_ = payload_size; | |
int max_outgoing = 100; | |
cqs_.reserve(num_threads_); | |
polling_threads_.reserve(num_threads_); | |
for (int i = 0; i < num_threads_; i++) { | |
cqs_.emplace_back(); | |
} | |
for (int i = 0; i < num_threads_; i++) { | |
polling_threads_.emplace_back(&BenchmarkClient::PollCompletionQueue, this, | |
i, str); | |
} | |
for (int i = 0; i < max_outgoing; i++) { | |
AsyncClientCall *call = new AsyncClientCall; | |
int id = rand() % num_threads_; | |
call->stream = stubs_[id]->PrepareAsyncSendDataStreamFullDuplex( | |
&call->context, &cqs_[id]); | |
call->count = 0; | |
call->stream->StartCall((void *)call); | |
call->sendfinished = false; | |
call->finished = false; | |
call->count = 0; | |
call->writing = true; | |
} | |
for (int i = 0; i < num_threads_; i++) { | |
polling_threads_[i].join(); | |
} | |
} | |
void PollCompletionQueue(int id, const std::string &str) { | |
void *tag; | |
bool ok = false; | |
long onehm = 1024 * 1024 * 100; | |
long oneg = 1024 * 1024 * 1024; | |
long packets_to_report = oneg / payload_size_; | |
long packets_to_report_100m = onehm / payload_size_; | |
long c = 0; | |
auto start = std::chrono::system_clock::now(); | |
while (cqs_[id].Next(&tag, &ok)) { | |
AsyncClientCall *call = static_cast<AsyncClientCall *>(tag); | |
if (not ok) { | |
std::cout << "call = " << call << std::endl; | |
std::cout << "count " << call->count << std::endl; | |
std::cout << "sendfinished " << call->sendfinished << std::endl; | |
std::cout << "finished " << call->finished << std::endl; | |
} | |
GPR_ASSERT(ok); | |
if ((call->count < 10) && (call->writing == true)) { | |
Data d; | |
d.set_data(str); | |
call->writing = false; | |
call->stream->Write(d, (void *)call); | |
call->count ++; | |
continue; | |
} | |
if ((call->count <= 10) && (call->writing == false)) { | |
call->writing = true; | |
call->stream->Read(&call->ack, (void*) call); | |
continue; | |
} | |
if (not call->sendfinished) { | |
call->stream->WritesDone((void *)call); | |
call->sendfinished = true; | |
continue; | |
} | |
if (not call->finished) { | |
call->stream->Finish(&call->status, (void *)call); | |
call->finished = true; | |
continue; | |
} | |
if (call->status.ok()) { | |
c += 10; | |
if (c > packets_to_report_100m) { | |
auto end = std::chrono::system_clock::now(); | |
std::chrono::duration<double> diff = end - start; | |
double gbps = 8.0 * onehm / diff.count() / 1e9; | |
std::cout << id << ": " << gbps << " Gbps" << std::endl; | |
start = end; | |
c = 0; | |
} | |
} else { | |
std::cout << "RPC failed" << std::endl; | |
} | |
delete call; | |
call = new AsyncClientCall; | |
call->stream = stubs_[id]->PrepareAsyncSendDataStreamFullDuplex( | |
&call->context, &cqs_[id]); | |
call->stream->StartCall((void *)call); | |
call->sendfinished = false; | |
call->finished = false; | |
call->count = 0; | |
call->writing = true; | |
} | |
} | |
private: | |
struct AsyncClientCall { | |
Ack ack; | |
ClientContext context; | |
Status status; | |
std::unique_ptr<ClientAsyncReaderWriter<Data, Ack>> stream; | |
int count; | |
bool sendfinished; | |
bool finished; | |
bool writing; | |
}; | |
int num_threads_; | |
int payload_size_; | |
std::vector<std::unique_ptr<Benchmark::Stub>> stubs_; | |
std::vector<CompletionQueue> cqs_; | |
std::vector<std::thread> polling_threads_; | |
}; | |
/// Client entry point.
/// Usage: <prog> <num_threads> <server_ip> <port> <payload_size>
/// Streams payloads of `payload_size` bytes ('a' repeated) to server_ip:port.
int main(int argc, char **argv) {
  if (argc < 5) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "client")
              << " <num_threads> <server_ip> <port> <payload_size>"
              << std::endl;
    return 1;
  }
  const int num_threads = std::stoi(argv[1]);
  const std::string ip(argv[2]);
  const int port = std::stoi(argv[3]);
  const int payload_size = std::stoi(argv[4]);
  // Reject values that would divide by zero or, once converted to size_t,
  // request an enormous payload allocation.
  if (num_threads <= 0 || payload_size <= 0) {
    std::cerr << "num_threads and payload_size must be positive" << std::endl;
    return 1;
  }

  const std::string payload(static_cast<size_t>(payload_size), 'a');
  const std::string bind_address = ip + ":" + std::to_string(port);
  BenchmarkClient client(bind_address, num_threads);
  client.Run(payload, payload_size);
  return 0;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
syntax = "proto3";

package benchmark;

// Throughput benchmark service.
service Benchmark {
  // Bidirectional stream: the client sends Data messages and receives Ack
  // messages. In the accompanying server code, one Ack answers each Data.
  rpc SendDataStreamFullDuplex(stream Data) returns (stream Ack) {}
}

// Payload carried to the server.
message Data {
  bytes data = 1;
}

// Empty acknowledgement.
message Ack {
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <fstream> | |
#include <grpc/grpc.h> | |
#include <grpcpp/server.h> | |
#include <grpcpp/server_builder.h> | |
#include <grpcpp/server_context.h> | |
#include <iostream> | |
#include <string> | |
#include <thread> | |
#include <unistd.h> | |
#include "benchmark.grpc.pb.h" | |
using grpc::Server; | |
using grpc::ServerAsyncReaderWriter; | |
using grpc::ServerBuilder; | |
using grpc::ServerCompletionQueue; | |
using grpc::ServerContext; | |
using grpc::ServerReader; | |
using grpc::Status; | |
using benchmark::Ack; | |
using benchmark::Benchmark; | |
using benchmark::Data; | |
class BenchmarkAsyncImpl final { | |
public: | |
~BenchmarkAsyncImpl() { | |
server_->Shutdown(); | |
for (size_t i = 0; i < cqs_.size(); i++) { | |
cqs_[i]->Shutdown(); | |
} | |
} | |
void Run(std::string server_address, int port, size_t num_threads) { | |
ServerBuilder builder; | |
std::string bind_address = server_address + ":" + std::to_string(port); | |
builder.AddListeningPort(bind_address, grpc::InsecureServerCredentials()); | |
builder.RegisterService(&service_); | |
for (size_t i = 0; i < num_threads; i++) { | |
cqs_.push_back(std::move(builder.AddCompletionQueue())); | |
} | |
server_ = builder.BuildAndStart(); | |
std::cout << "Async server listening on " << server_address << std::endl; | |
std::vector<std::thread> threads; | |
threads.reserve(num_threads); | |
for (size_t i = 0; i < num_threads; i++) { | |
threads.emplace_back(HandleRPCs, &service_, cqs_[i].get()); | |
} | |
for (size_t i = 0; i < num_threads; i++) { | |
threads[i].join(); | |
} | |
} | |
private: | |
class HandleRPC { | |
public: | |
HandleRPC(Benchmark::AsyncService *service, ServerCompletionQueue *cq) | |
: service_(service), cq_(cq), stream_(&ctx_), status_(CREATE) { | |
P(nullptr, true); | |
} | |
void P(std::string *d, bool ok) { | |
if (status_ == CREATE) { | |
status_ = PROCESS; | |
service_->RequestSendDataStreamFullDuplex(&ctx_, &stream_, cq_, cq_, this); | |
} else if (status_ == PROCESS) { | |
new HandleRPC(service_, cq_); | |
status_ = PROCESSING_READ; | |
stream_.Read(&data_, this); | |
} else if (status_ == PROCESSING_READ) { | |
d->assign(data_.data()); | |
data_.Clear(); | |
if (ok) { | |
status_ = PROCESSING_WRITE; | |
stream_.Write(ack_, this); | |
} else { | |
status_ = FINISH; | |
stream_.Finish(Status::OK, this); | |
} | |
} else if (status_ == PROCESSING_WRITE) { | |
if (ok) { | |
status_ = PROCESSING_READ; | |
stream_.Read(&data_, this); | |
} else { | |
status_ = FINISH; | |
stream_.Finish(Status::OK, this); | |
} | |
} else { | |
GPR_ASSERT(status_ == FINISH); | |
delete this; | |
} | |
} | |
private: | |
Benchmark::AsyncService *service_; | |
ServerCompletionQueue *cq_; | |
ServerContext ctx_; | |
Data data_; | |
Ack ack_; | |
ServerAsyncReaderWriter<Ack, Data> stream_; | |
enum CallStatus { CREATE, PROCESS, PROCESSING_READ, PROCESSING_WRITE, FINISH }; | |
CallStatus status_; | |
}; | |
static void HandleRPCs(Benchmark::AsyncService *service, | |
ServerCompletionQueue *cq) { | |
new HandleRPC(service, cq); | |
void *tag; | |
bool ok; | |
std::string a; | |
while (true) { | |
GPR_ASSERT(cq->Next(&tag, &ok)); | |
static_cast<HandleRPC *>(tag)->P(&a, ok); | |
} | |
} | |
std::vector<std::unique_ptr<ServerCompletionQueue>> cqs_; | |
Benchmark::AsyncService service_; | |
std::unique_ptr<Server> server_; | |
}; | |
void RunServer(std::string ip, int port, size_t num_threads) { | |
BenchmarkAsyncImpl server; | |
server.Run(ip, port, num_threads); | |
} | |
/// Server entry point.
/// Usage: <prog> <num_threads> <port>
/// Listens on 0.0.0.0:<port>.
int main(int argc, char **argv) {
  if (argc < 3) {
    std::cerr << "usage: " << (argc > 0 ? argv[0] : "server")
              << " <num_threads> <port>" << std::endl;
    return 1;
  }
  const std::string ip = "0.0.0.0";
  const int num_threads = atoi(argv[1]);
  const int port = atoi(argv[2]);
  // atoi yields 0 for garbage input; a negative thread count would otherwise
  // wrap to a huge size_t.
  if (num_threads <= 0 || port <= 0 || port > 65535) {
    std::cerr << "num_threads must be positive and port in 1..65535"
              << std::endl;
    return 1;
  }
  RunServer(ip, port, static_cast<size_t>(num_threads));
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.