Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Last active August 28, 2024 23:18
Show Gist options
  • Save jweinst1/7ffd201e95ed91687ffb11d998ed33b0 to your computer and use it in GitHub Desktop.
Save jweinst1/7ffd201e95ed91687ffb11d998ed33b0 to your computer and use it in GitHub Desktop.
leaderless consensus
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <signal.h>
#include <errno.h>
#include <stdint.h>
//--------system headers -------//
#include <unistd.h>
#include <sys/un.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/stat.h>
//--------C++ standard library -------//
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
// Returns a heap-allocated copy of the NUL-terminated string `src`.
// The caller owns the result and must release it with free().
// Returns nullptr if the allocation fails. `src` must not be null.
static char* str_dupl(const char* src) {
    const size_t src_size = strlen(src) + 1;  // include the terminating NUL
    char* newstr = static_cast<char*>(malloc(src_size));
    if (newstr == nullptr) {
        // Never memcpy through a failed allocation.
        return nullptr;
    }
    memcpy(newstr, src, src_size);
    return newstr;
}
// Debug helper: writes each byte of `vec` to stdout as an unsigned decimal
// followed by a space, then ends the line with "|".
static void debugByteVector(const std::vector<unsigned char>& vec) {
    for (size_t idx = 0; idx < vec.size(); ++idx) {
        printf("%u ", static_cast<unsigned>(vec[idx]));
    }
    puts("|");  // puts appends the newline
}
// Returns the current value of errno and clears it to 0, so that a later
// errno check is not confused by this stale value.
static int getAndResetErrNo() {
    const int saved = errno;
    errno = 0;
    return saved;
}
// Clears errno to 0; called after an expected EAGAIN/EWOULDBLOCK has been handled.
static void resetErrNo() { errno = 0; }
// Returns 1 when errno says a non-blocking call would have blocked
// (EAGAIN, or EWOULDBLOCK where that is a distinct value), 0 otherwise.
// errno itself is not modified.
static int errNoIsWouldBlock() {
    const int code = errno;
    if (code == EAGAIN) {
        return 1;
    }
    return code == EWOULDBLOCK ? 1 : 0;
}
// Reports an unexpected failure on stderr, naming the action that failed and
// the current errno, then terminates the process with exit status 2.
static void exitAndErrorNo(const char* lastAction) {
    const int code = errno;
    fprintf(stderr, "Got unexpected lastAction=%s, errno=%d\n", lastAction, code);
    exit(2);
}
// Capacity in bytes of sockaddr_un::sun_path, including space for the
// terminating NUL; a usable path is therefore at most this value minus one.
static constexpr size_t getMaxSizeOfUnixPath() {
    return sizeof(sockaddr_un::sun_path);
}
// Returns true if `path` names an existing filesystem entry of any type.
// errno is preserved across the call: a failing stat() is an ordinary "no"
// answer here, and callers in this file report errno after later system calls.
static inline bool pathExists(const char* path) {
    const int savedErrno = errno;
    struct stat sbuf;
    const bool exists = stat(path, &sbuf) == 0;
    errno = savedErrno;
    return exists;
}
// Enables (nonBlocking == true) or disables (nonBlocking == false) O_NONBLOCK
// on `sockfd`, leaving all other file status flags untouched.
// Returns true on success, false if either fcntl() call fails (errno is set).
static bool set_non_blocking(int sockfd, bool nonBlocking) {
    const int flags = fcntl(sockfd, F_GETFL, 0);
    if (flags == -1) {
        // Without the current flags we would overwrite them with garbage.
        return false;
    }
    const int newFlags = nonBlocking ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
    // POSIX only promises "a value other than -1" on success for F_SETFL.
    return fcntl(sockfd, F_SETFL, newFlags) != -1;
}
// Creates a listening AF_UNIX stream socket bound to `path` (backlog 5).
// Any existing filesystem entry at `path` is removed first, because bind()
// fails with EADDRINUSE when the path already exists. Note that this removes
// the entry whatever its type.
// Returns the listening fd, or std::nullopt with errno describing the failure
// (ENAMETOOLONG if `path` does not fit in sun_path). No descriptor is leaked
// on failure.
static std::optional<int> create_server_socket(const char* path) {
    if (strlen(path) > getMaxSizeOfUnixPath() - 1) {
        errno = ENAMETOOLONG;
        return std::nullopt;
    }
    if (pathExists(path)) {
        remove(path);
    }
    struct sockaddr_un unix_addr;
    memset(&unix_addr, 0, sizeof(unix_addr));
    unix_addr.sun_family = AF_UNIX;
    strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1);
    const int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sfd == -1) {
        return std::nullopt;
    }
    if (bind(sfd, reinterpret_cast<struct sockaddr*>(&unix_addr), sizeof(unix_addr)) == -1 ||
        listen(sfd, 5) == -1) {
        // Release the half-initialized socket but report the failing call's errno.
        const int savedErrno = errno;
        close(sfd);
        errno = savedErrno;
        return std::nullopt;
    }
    return sfd;
}
// Connects a new AF_UNIX stream socket to the server listening at `path`.
// Returns the connected socket (blocking mode, as created by socket()), or
// std::nullopt with errno describing the failure (ENAMETOOLONG if `path` does
// not fit in sun_path). No descriptor is leaked on failure.
static std::optional<int> create_client_socket(const char* path) {
    // Validate before creating the socket so this early return cannot leak it.
    if (strlen(path) > getMaxSizeOfUnixPath() - 1) {
        errno = ENAMETOOLONG;
        return std::nullopt;
    }
    struct sockaddr_un unix_addr;
    memset(&unix_addr, 0, sizeof(unix_addr));
    unix_addr.sun_family = AF_UNIX;
    strncpy(unix_addr.sun_path, path, getMaxSizeOfUnixPath() - 1);
    const int cfd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (cfd == -1) {
        return std::nullopt;
    }
    if (connect(cfd, reinterpret_cast<struct sockaddr*>(&unix_addr), sizeof(unix_addr)) == -1) {
        // Release the socket but report connect()'s errno to the caller.
        const int savedErrno = errno;
        close(cfd);
        errno = savedErrno;
        return std::nullopt;
    }
    return cfd;
}
// Type tags: every serialized item in a request body starts with one of these.
static constexpr unsigned char PROT_U8 = 1;       // then 1 raw byte
static constexpr unsigned char PROT_U32 = 2;      // then a uint32_t in native byte order
static constexpr unsigned char PROT_U64 = 3;      // then a uint64_t in native byte order
static constexpr unsigned char PROT_STR = 4;      // then uint32_t length + that many bytes
static constexpr unsigned char PROT_STR_LST = 5;  // then uint32_t count + count x (uint32_t length + bytes)
// One decoded request item; the active alternative corresponds to its tag.
using ReqItem = std::variant<unsigned char, uint32_t, uint64_t, std::string, std::vector<std::string>>;
// Decodes a request body (the bytes after the 4-byte size header) and appends
// the items to `vec` in wire order.
//
// Item layout: one PROT_* tag byte, then
//   PROT_U8      : 1 byte
//   PROT_U32     : uint32_t, native byte order
//   PROT_U64     : uint64_t, native byte order
//   PROT_STR     : uint32_t length + that many bytes
//   PROT_STR_LST : uint32_t count + count x (uint32_t length + bytes)
//
// The bytes arrive from a socket, so every field is bounds-checked against
// `size` before it is read. A truncated item or an unknown tag is reported on
// stderr and terminates the process with exit status 2.
static void buildFromBytes(std::vector<ReqItem>& vec, const unsigned char* bytes, size_t size) {
    size_t i = 0;
    // Exits unless at least `need` unread bytes remain at offset i.
    // Written as `size - i < need` so the check itself cannot overflow.
    auto require = [&](size_t need, const char* what) {
        if (size - i < need) {
            fprintf(stderr, "Truncated %s during deserialization at offset %zu (need %zu bytes, have %zu)\n",
                    what, i, need, size - i);
            exit(2);
        }
    };
    auto readU32 = [&](const char* what) {
        require(sizeof(uint32_t), what);
        uint32_t val = 0;
        memcpy(&val, bytes + i, sizeof(val));
        i += sizeof(val);
        return val;
    };
    auto readStr = [&]() {
        const uint32_t strSize = readU32("string length");
        require(strSize, "string data");
        std::string strObj(reinterpret_cast<const char*>(bytes + i), strSize);
        i += strSize;
        return strObj;
    };
    while (i < size) {
        const unsigned char tag = bytes[i];
        i += 1;
        ReqItem elem;
        switch (tag) {
        case PROT_U8:
            require(1, "u8");
            elem = bytes[i];
            i += 1;
            break;
        case PROT_U32:
            elem = readU32("u32");
            break;
        case PROT_U64: {
            require(sizeof(uint64_t), "u64");
            uint64_t val = 0;
            memcpy(&val, bytes + i, sizeof(val));
            i += sizeof(val);
            elem = val;
            break;
        }
        case PROT_STR:
            elem = readStr();
            break;
        case PROT_STR_LST: {
            // The count is untrusted, so strings are appended one at a time
            // rather than reserving `count` slots up front.
            const uint32_t count = readU32("string list count");
            std::vector<std::string> strLst;
            for (uint32_t j = 0; j < count; ++j) {
                strLst.push_back(readStr());
            }
            elem = std::move(strLst);
            break;
        }
        default:
            fprintf(stderr, "Unexpected byte code during serialization, %u\n", tag);
            exit(2);
        }
        vec.push_back(std::move(elem));
    }
}
// Builds one framed request: a 4-byte size header followed by a body of
// tagged items (see the PROT_* constants). Items are appended with the push*
// methods; putSize() must be called after the last push so that the header
// holds the body length. buildFromBytes() decodes the body.
class ReqBuilder {
public:
    // Stores the current body length (total size minus the header) in the header.
    void putSize() {
        const uint32_t bodyLen = getReqSize();
        memcpy(_req.data(), &bodyLen, sizeof bodyLen);
    }
    // Appends PROT_U8 followed by the byte itself.
    void pushU8(unsigned char byte) {
        appendTag(PROT_U8);
        _req.push_back(byte);
    }
    // Appends PROT_U32 followed by `num` in native byte order.
    void pushU32(uint32_t num) {
        appendTag(PROT_U32);
        appendRaw(&num, sizeof num);
    }
    // Appends PROT_U64 followed by `num` in native byte order.
    void pushU64(uint64_t num) {
        appendTag(PROT_U64);
        appendRaw(&num, sizeof num);
    }
    // Appends PROT_STR, the length as a uint32_t, then the string's bytes.
    void pushStr(const std::string& stringObj) {
        appendTag(PROT_STR);
        appendSizedString(stringObj);
    }
    // Appends PROT_STR_LST, the element count as a uint32_t, then every string
    // as (uint32_t length, bytes).
    void pushStrList(const std::vector<std::string>& stringLst) {
        appendTag(PROT_STR_LST);
        const uint32_t count = stringLst.size();
        // Grow once for the whole list instead of once per element.
        _req.reserve(_req.size() + sizeof count + calcSizeOfStrLst(stringLst));
        appendRaw(&count, sizeof count);
        for (const std::string& entry : stringLst) {
            appendSizedString(entry);
        }
    }
    // The whole buffer, header included.
    const std::vector<unsigned char>& getReq() const { return _req; }
    // The body only (after the header); this is what buildFromBytes() decodes.
    const unsigned char* getReqData() const { return _req.data() + sizeof(uint32_t); }
    size_t getReqSize() const { return _req.size() - sizeof(uint32_t); }
    // Header plus body: the bytes to put on the wire.
    const unsigned char* getTotalReqData() const { return _req.data(); }
    size_t getTotalReqSize() const { return _req.size(); }
private:
    void appendTag(unsigned char tag) { _req.push_back(tag); }
    // Appends `len` raw bytes starting at `src`.
    void appendRaw(const void* src, size_t len) {
        const unsigned char* first = static_cast<const unsigned char*>(src);
        _req.insert(_req.end(), first, first + len);
    }
    // Appends a uint32_t length prefix followed by the string's bytes.
    void appendSizedString(const std::string& text) {
        const uint32_t len = text.size();
        appendRaw(&len, sizeof len);
        appendRaw(text.data(), len);
    }
    // Encoded size of a string list's elements (length prefixes included,
    // count prefix excluded).
    size_t calcSizeOfStrLst(const std::vector<std::string>& stringLst) {
        size_t total = 0;
        for (const auto& obj : stringLst) {
            total += sizeof(uint32_t) + obj.size();
        }
        return total;
    }
    std::vector<unsigned char> _req = {0, 0, 0, 0};  // 4 header bytes, filled by putSize()
};
// Operation codes: the first item (a PROT_U8) of every cluster request.
static constexpr unsigned char OPER_INTRO = 1;           // joiner -> introducer; item 1: joiner's path
static constexpr unsigned char OPER_INTRO_LIST = 2;      // introducer -> joiner; item 1: list of member paths
static constexpr unsigned char OPER_INTRO_CONN = 3;      // joiner -> each listed member; item 1: joiner's path
static constexpr unsigned char OPER_INTRO_CONN_OK = 4;   // member -> joiner; no further items
static constexpr unsigned char OPER_INTRO_COMPLETE = 5;  // joiner -> introducer; no further items

// A decoded request plus its origin. Requests read from a pending (not yet
// registered) connection carry that socket in `conn`; requests read from an
// established member carry the member's path in `sender`.
struct ClusterRequest {
    std::optional<std::string> sender; // todo make this a variant
    std::optional<int> conn;
    std::vector<ReqItem> req;          // decoded items; req[0] is the operation code
};

// Handshake states, used both for a node's own state and for the state it
// records for each member.
static constexpr int MEMB_STATE_BEGIN = 0;             // nothing started yet
static constexpr int MEMB_STATE_INTRO_REC = 1;         // member: its OPER_INTRO was received
static constexpr int MEMB_STATE_INTRO_PROCESSING = 2;  // self: acting as introducer for a joiner
static constexpr int MEMB_STATE_INTRO_SENT = 3;        // OPER_INTRO sent to the introducer
static constexpr int MEMB_STATE_INTRO_CONN_SENT = 4;   // member: OPER_INTRO_CONN sent, awaiting OK
static constexpr int MEMB_STATE_INTRO_LIST = 5;        // self: member list received, connecting
static constexpr int MEMB_STATE_CONN = 6;              // connected

// Connection bookkeeping for one cluster member.
struct Member {
    int fd{-1};                   // socket used to talk to the member
    int state{MEMB_STATE_BEGIN};  // one of MEMB_STATE_*
};
class ClusterNode {
public:
ClusterNode(const char* path){
_path = path;
}
~ClusterNode(){
close(_server);
for (const auto& [ key, conn ] : _members) {
close(conn.fd);
}
}
void start() {
std::optional<int> fd = create_server_socket(_path.c_str());
if (fd.has_value()) {
_server = *fd;
assert(set_non_blocking(_server, true));
_isStarted = true;
} else {
fprintf(stderr, "Cannot create server socket, errno=%d\n", getAndResetErrNo());
exit(2);
}
}
const std::string& getPath() const { return _path; }
void collectRequests() {
assert(_isStarted);
struct sockaddr_un remote;
unsigned int sock_len = 0;
int incoming = accept(_server, (struct sockaddr*)&remote, &sock_len);
// todo make loop
if( incoming == -1 ) {
if(errNoIsWouldBlock()) {
resetErrNo();
} else {
exitAndErrorNo("Could not bind socket");
}
} else {
set_non_blocking(incoming, true);
_pending.push_back(incoming);
}
std::vector<int> toKeep;
// manual polling for now
for (size_t i = 0; i < _pending.size(); ++i) {
std::vector<unsigned char> bytes;
uint32_t req_size = 0;
read(_pending[i], &req_size, sizeof(req_size));
if (errNoIsWouldBlock()) {
resetErrNo();
toKeep.push_back(_pending[i]);
continue;
}
bytes.resize(req_size);
read(_pending[i], bytes.data(), req_size);
if (errNoIsWouldBlock()) {
exitAndErrorNo("Unexpected lack of body of request on pending");
}
std::vector<ReqItem> reqItems;
buildFromBytes(reqItems, bytes.data(), bytes.size());
ClusterRequest req;
req.conn = std::make_optional<int>(_pending[i]);
req.req = reqItems;
_requests.push_back(req);
}
_pending = toKeep;
for (const auto& [ key, conn ] : _members) {
std::vector<unsigned char> bytes;
uint32_t req_size = 0;
read(conn.fd, &req_size, sizeof(req_size));
if (errNoIsWouldBlock()) {
resetErrNo();
continue;
}
bytes.resize(req_size);
read(conn.fd, bytes.data(), req_size);
if (errNoIsWouldBlock()) {
exitAndErrorNo("Unexpected lack of body of request on formed member");
}
std::vector<ReqItem> reqItems;
buildFromBytes(reqItems, bytes.data(), bytes.size());
ClusterRequest req;
req.sender = std::make_optional<std::string>(key);
req.req = reqItems;
_requests.push_back(req);
}
}
void listMembersToVec(std::vector<std::string>& membs) {
for (const auto& [ key, conn ] : _members) {
membs.push_back(key);
}
}
void processIntroCompleteRequest(const ClusterRequest& req) {
if(!req.sender.has_value()) {
fprintf(stderr, "Got intro complete from non existing member\n");
exit(2);
}
const std::string sender = req.sender.value();
const auto toUpdate = _members.find(sender);
assert(toUpdate != _members.end());
if (toUpdate->second.state != MEMB_STATE_INTRO_REC) {
fprintf(stderr, "Expected member %s to be in intro rec state, was in state %d\n", sender.c_str(), toUpdate->second.state);
exit(2);
}
toUpdate->second.state = MEMB_STATE_CONN;
// no longer in processing state, go back
_myState = MEMB_STATE_CONN;
}
void processIntroConnOkRequest(const ClusterRequest& req) {
if(!req.sender.has_value()) {
fprintf(stderr, "Got conn ok from non existing member\n");
exit(2);
}
const std::string sender = req.sender.value();
const auto toUpdate = _members.find(sender);
assert(toUpdate != _members.end());
toUpdate->second.state = MEMB_STATE_CONN;
if (checkAndPossiblyCompleteIntroduction()) {
printf("The introduction for myself %s is complete\n", _path.c_str());
}
}
void processIntroConnRequest(const ClusterRequest& req) {
const std::string sender = std::get<std::string>(req.req[1]);
// handle case of processing here TODO
if (req.conn.has_value()) {
Member m;
m.fd = req.conn.value();
m.state = MEMB_STATE_CONN;
ReqBuilder build;
build.pushU8(OPER_INTRO_CONN_OK);
build.putSize();
write(m.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending intro conn ok to %s", sender.c_str());
exit(2);
}
_members[sender] = m;
} else if (req.sender.has_value()) {
// not valid, this should only come from new member
fprintf(stderr, "Unexpected connected member '%s' sent intro conn request\n", req.sender.value().c_str());
exit(2);
} else {
fprintf(stderr, "Got neither connection nor sender when hanndling intro conn\n");
exit(4);
}
}
bool checkAndPossiblyCompleteIntroduction() {
assert(_myState == MEMB_STATE_INTRO_LIST);
for (const auto& [ key, conn ] : _members) {
if (conn.state == MEMB_STATE_INTRO_CONN_SENT) {
// not ready, still waiting for connection oks
return false;
}
}
for (auto& [ key, conn ] : _members) {
if (conn.state == MEMB_STATE_INTRO_SENT) {
// send completion request
ReqBuilder build;
build.pushU8(OPER_INTRO_COMPLETE);
build.putSize();
write(conn.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending intro complete to %s", key.c_str());
exit(2);
}
conn.state = MEMB_STATE_CONN;
_myState = MEMB_STATE_CONN;
return true;
}
}
return false;
}
void processIntroListRequest(const ClusterRequest& req) {
const std::vector<std::string> clusterMembers = std::get<std::vector<std::string>>(req.req[1]);
ReqBuilder build;
build.pushU8(OPER_INTRO_CONN);
build.pushStr(_path);
build.putSize();
for (const auto& memb : clusterMembers) {
if (_members.find(memb) != _members.end()) {
fprintf(stderr, "Got intro list for '%s', but is already known member", memb.c_str());
exit(2);
}
std::optional<int> remote = create_client_socket(memb.c_str());
if (!remote.has_value()) {
fprintf(stderr, "Failed to connect to %s, errno=%d", memb.c_str(), getAndResetErrNo());
exit(2);
}
Member m;
m.fd = remote.value();
write(m.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending intro resp to %s", memb.c_str());
exit(2);
}
m.state = MEMB_STATE_INTRO_CONN_SENT;
_members[memb] = m;
}
_myState = MEMB_STATE_INTRO_LIST;
if (clusterMembers.size() == 0) {
// no members to send to, thus we can complete now.
if(!checkAndPossiblyCompleteIntroduction()) {
fprintf(stderr, "Failed to complete introduction early despite no members to connect with\n");
}
}
}
void processIntroRequest(const ClusterRequest& req) {
const std::string sender = std::get<std::string>(req.req[1]);
if (_members.find(sender) != _members.end()) {
// to do handle re-intro of existing member
fprintf(stderr, "Got intro request for same member twice, %s\n", sender.c_str());
return;
}
Member m;
m.fd = req.conn.value();
m.state = MEMB_STATE_INTRO_REC;
_members[sender] = m;
ReqBuilder build;
build.pushU8(OPER_INTRO_LIST);
std::vector<std::string> myMembs;
listMembersToVec(myMembs);
build.pushStrList(myMembs);
build.putSize();
write(m.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending intro resp to %s", sender.c_str());
exit(2);
}
// Once a beginning member acts as a cluster introducer, it considers itself connected.
if (_myState == MEMB_STATE_BEGIN)
_myState = MEMB_STATE_INTRO_PROCESSING;
}
void processRequests() {
for (const auto& req: _requests) {
// todo error handle
const unsigned char opCode = std::get<unsigned char>(req.req[0]);
switch (opCode) {
case OPER_INTRO:
processIntroRequest(req);
break;
case OPER_INTRO_LIST:
processIntroListRequest(req);
break;
case OPER_INTRO_CONN:
processIntroConnRequest(req);
break;
case OPER_INTRO_CONN_OK:
processIntroConnOkRequest(req);
break;
case OPER_INTRO_COMPLETE:
processIntroCompleteRequest(req);
break;
default:
fprintf(stderr, "Unknown request code %u\n", opCode);
exit(3);
}
}
_requests.clear();
}
void doWork() {
collectRequests();
processRequests();
// check state
}
int getState() const { return _myState; }
bool connectWith(const std::string& target) {
if (_myState != MEMB_STATE_BEGIN) {
return false;
}
if (_members.find(target) != _members.end()) {
return false;
}
std::optional<int> remote = create_client_socket(target.c_str());
if (!remote.has_value()) {
fprintf(stderr, "Failed to connect to %s, errno=%d", target.c_str(), getAndResetErrNo());
return false;
}
Member m;
m.fd = remote.value();
assert(set_non_blocking(m.fd, true));
m.state = MEMB_STATE_INTRO_SENT;
_members[target] = m;
ReqBuilder build;
build.pushU8(OPER_INTRO);
build.pushStr(_path);
build.putSize();
write(m.fd, build.getTotalReqData(), build.getTotalReqSize());
if (errNoIsWouldBlock()) {
fprintf(stderr, "Got unexpected would block when sending intro resp to %s", target.c_str());
return false;
}
_myState = MEMB_STATE_INTRO_SENT;
return true;
}
void debug() const {
printf("Myself: %s\n", _path.c_str());
printf("State: %d\n", _myState);
printf("isStarted: %s\n", _isStarted ? "true" : "false");
puts("-----------------------------------");
}
private:
friend class ClusterNodeTests;
bool _isStarted = false;
int _server = -1;
int _myState = MEMB_STATE_BEGIN;
std::string _path;
std::vector<int> _pending;
std::vector<ClusterRequest> _requests;
std::unordered_map<std::string, Member> _members;
};
//------- tests ---------
static unsigned _failures = 0;
static void check_cond(int cond, const char* condstr, unsigned line) {
if (!cond) {
fprintf(stderr, "Failed cond '%s' at line %u\n", condstr, line);
++_failures;
}
}
#define CHECKIT(cnd) check_cond(cnd, #cnd, __LINE__)
static void test_reqBuilder() {
ReqBuilder b;
b.pushU8(3);
std::string foo = "foo";
std::vector<std::string> foos = {"abc", "def"};
b.pushStr(foo);
b.pushStrList(foos);
b.putSize();
std::vector<ReqItem> elems;
buildFromBytes(elems, b.getReqData(), b.getReqSize());
CHECKIT(elems.size() == 3);
CHECKIT(std::get_if<unsigned char>(&elems[0]) != nullptr);
CHECKIT(std::get_if<std::string>(&elems[1]) != nullptr);
CHECKIT(std::get_if<std::vector<std::string>>(&elems[2]) != nullptr);
}
class ClusterNodeTests {
public:
static void testIntroRec();
};
void ClusterNodeTests::testIntroRec() {
ClusterNode n1("foo");
ClusterNode n2("bar");
n1.start();
n2.start();
CHECKIT(n1._requests.size() == 0);
CHECKIT(n2._requests.size() == 0);
CHECKIT(n1.connectWith(n2.getPath()));
CHECKIT(n1._members.size() == 1);
CHECKIT(n1._members.find(n2.getPath()) != n1._members.end());
CHECKIT(n1._myState == MEMB_STATE_INTRO_SENT);
n2.collectRequests();
CHECKIT(n2._requests.size() == 1);
CHECKIT(std::get_if<unsigned char>(&(n2._requests[0].req[0])) != nullptr);
CHECKIT(std::get_if<std::string>(&(n2._requests[0].req[1])) != nullptr);
CHECKIT(n2._pending.size() == 0);
CHECKIT(n2._members.size() == 0);
n2.processRequests();
CHECKIT(n2._requests.size() == 0);
CHECKIT(n2._members.size() == 1);
CHECKIT(n2._members.find(n1.getPath()) != n2._members.end());
CHECKIT(n2._members.find(n1.getPath())->second.state == MEMB_STATE_INTRO_REC);
CHECKIT(n2._myState == MEMB_STATE_INTRO_PROCESSING);
n1.collectRequests();
CHECKIT(n1._requests.size() == 1);
CHECKIT(std::get_if<std::vector<std::string>>(&(n1._requests[0].req[1])) != nullptr);
CHECKIT(n1._pending.size() == 0);
n1.processRequests();
CHECKIT(n1._myState == MEMB_STATE_INTRO_LIST);
}
int main(int argc, char const *argv[])
{
test_reqBuilder();
ClusterNodeTests::testIntroRec();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment