Created
March 31, 2017 08:04
-
-
Save risingnote/f5039cfa3140d4551ae53abf5b71ef6b to your computer and use it in GitHub Desktop.
TPMPC SPDZ MPC Voting Program
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Receive a share of votes for between 1 and 3 workshop talks, store in the
voter history (voter id, talk id * 3) and votes count (talk id, points total).
The voter history supports a voter recasting their votes.
The votes count supports calculating the top 3 talks, which are revealed and
returned to the management server.
At startup the management server (verified by public key) sends a list of shares
of valid voter ids and talk ids. The program is limited to a fixed array size of
voter ids and talk ids, with unused slots set to zero.
Voter shares are stored to disk and reloaded on restart, to provide persistence.
This mpc program is to be used in conjunction with the spdz-voting-gui.
"""
from Compiler.types import sint, regint, Array, Matrix | |
from Compiler.instructions import listen, acceptclientconnection, readsharesfromfile, writesharestofile | |
from Compiler.library import print_ln, if_, if_e, else_, for_range, do_while | |
from Compiler.util import if_else | |
# program is instance of Program class injected by compilerLib.run.
program.set_bit_length(32)
# Public key of management server, providing voter id data.
# 64 hex characters; reformat_public_key() splits it into 8-character parts
# which are compared against the key read from the management connection.
# Prod
AUTH_PUBLIC_KEY = '1da3dc974b0c04374c3b4797a522f50507b81584af653e08122db36fe178a521'
# Port passed to listen() and to every acceptclientconnection() call.
VOTING_PORTNUM = 14000
# Register that receives the management server's socket id on connection;
# reused later to read lookup data and to send results back.
MGMT_SOCKET_ID = regint()
# Fixed capacities: number of voter id slots (also the number of voter history
# rows) and number of talk id slots read from the management server.
MAX_NUM_VOTERIDS = 100
MAX_NUM_TALKIDS = 40
def or_op(a, b):
    """Convenience method to run a logical OR on sints.

    Uses the arithmetic identity a + b - a*b, which is 1 when either
    operand is 1 and 0 when both are 0 (operands are expected to be 0/1).
    """
    both_set = a * b
    return a + b - both_set
def reformat_public_key(key_hex):
    """Convert authorised public key from hex string into a list of signed
    32 bit ints (8 of them for a 64 character key), ready for later validation.

    key_hex -- hexadecimal string, consumed in chunks of 8 characters.

    Returns a concrete list (not a lazy map object) so callers can index it
    on both Python 2 and Python 3.
    """
    def hex_to_signed(hex_int):
        """Parse one hex chunk and reinterpret it as a signed 32 bit value."""
        int_val = int(hex_int, 16)
        # Values with the top bit set wrap round to negative numbers.
        if int_val > 0x7FFFFFFF:
            int_val -= 0x100000000
        return int_val

    chunks = [key_hex[start:start + 8] for start in range(0, len(key_hex), 8)]
    return [hex_to_signed(chunk) for chunk in chunks]
def mgmtserver_init():
    """High level control for management server setup of lookup data.

    Returns a 5 element list:
    [key validation count, voter id Array, voter id count,
     votes Matrix[talk id][points], talk id count].
    The lookup entries stay as empty placeholders unless the branch that
    reads them from the management server is the one that runs.
    """
    key_parts_matched = mgmtserver_connection()

    # Mutable holders so the branch body below can hand its results back out.
    voter_lookup = [Array(0, sint), sint(0)]
    talk_lookup = [Matrix(0, 2, sint), sint(0)]

    @if_e(key_parts_matched != 8)
    def reject_key():
        print_ln('Management server public key validation failed.')

    @else_
    def read_lookup_shares():
        print_ln('Management server public key accepted.')
        voter_lookup[:] = setup_voter_data(MAX_NUM_VOTERIDS)
        talk_lookup[:] = setup_talk_data(MAX_NUM_TALKIDS)

    return [key_parts_matched] + voter_lookup + talk_lookup
def mgmtserver_connection():
    """Wait for socket connection from management server, read the client
    public key and compare it part by part with AUTH_PUBLIC_KEY.

    Returns the number of matching key parts; 8 indicates all good.
    """
    expected_parts = reformat_public_key(AUTH_PUBLIC_KEY)
    acceptclientconnection(MGMT_SOCKET_ID, VOTING_PORTNUM)
    received_parts = regint.read_client_public_key(MGMT_SOCKET_ID)

    matched = 0
    for position, received in enumerate(received_parts):
        # Each matching part contributes 1 to the total, a mismatch 0.
        same_part = received == expected_parts[position]
        matched += if_else(same_part, 1, 0)
    return matched
def setup_voter_data(list_size):
    """Read shares of the voter ids list from the management server.

    list_size -- number of id slots to receive.

    Returns (Array of the received id shares, id count). The count is the
    1-based position of the last non zero slot, which equals the number of
    ids when the unused (zero) slots all sit at the end of the list.
    """
    received_ids = sint.receive_from_client(list_size, MGMT_SOCKET_ID)
    voter_ids = Array(list_size, sint)
    last_used = regint(0)
    for slot in range(list_size):
        id_share = received_ids[slot]
        voter_ids[slot] = id_share
        last_used = if_else(id_share != 0, slot + 1, last_used)
    return voter_ids, last_used
def setup_talk_data(list_size):
    """Read shares of the talk ids list from the management server.

    list_size -- number of id slots to receive.

    Returns (Matrix with one row per slot holding [talk id, points], id count).
    Points start at zero. The count is the 1-based position of the last non
    zero slot in the received list.
    """
    received_ids = sint.receive_from_client(list_size, MGMT_SOCKET_ID)
    # Column 0 holds the talk id, column 1 the points total (initially 0).
    talk_totals = Matrix(list_size, 2, sint)
    talk_totals.assign_all(sint(0))
    last_used = regint(0)
    for row in range(list_size):
        id_share = received_ids[row]
        talk_totals[row][0] = id_share
        last_used = if_else(id_share != 0, row + 1, last_used)
    return talk_totals, last_used
def load_voter_history(votes_count):
    """Load the shares of voter history from disk at startup.

    Stored votes are read 4 shares at a time and replayed through
    apply_votes, which also rebuilds the points totals in votes_count.

    Returns the voter history Matrix (MAX_NUM_VOTERIDS rows of 4 shares).
    """
    history = Matrix(MAX_NUM_VOTERIDS, 4, sint)
    history.assign_all(sint(0))
    # Held in a one element Array so the loop body can update it between passes.
    file_posn = Array(1, regint)
    file_posn[0] = 0
    print_ln("Loading voter history.")

    @do_while
    def read_next_record():
        next_posn = regint()
        # One stored vote is 4 shares.
        record = [sint() for _ in range(4)]
        readsharesfromfile(file_posn[0], next_posn, *record)

        # First time the file may not exist, so check before using the shares.
        # presumably -2 is the "file not found" position code - TODO confirm.
        @if_(next_posn != -2)
        def replay_record():
            file_posn[0] = next_posn
            # Process shares as if input by user to build up history.
            stored_vote = Array(4, sint)
            for column in range(4):
                stored_vote[column] = record[column]
            apply_votes(stored_vote, history, votes_count)

        # Keep looping only while a positive next position was reported.
        return next_posn > 0

    return history
def accept_votes(voter_history, votes_count, voter_ids):
    """Listen in a never ending loop for client connections, validate and
    process votes.

    For each connection: read 4 shares (voter id plus up to 3 talk ids),
    validate, send the validation result back to the client and, when valid,
    apply the vote, send the new top 3 and persist the shares to file.
    """
    @do_while
    def serve_next_voter():
        voter_socket = regint()
        print_ln('Waiting for voter data....')
        acceptclientconnection(voter_socket, VOTING_PORTNUM)
        # The voter's public key is read but not checked here.
        regint.read_client_public_key(voter_socket)

        # Expect voterid, talk id, talk id | 0, talk id | 0.
        received = sint.receive_from_client(4, voter_socket)
        ballot = Array(4, sint)
        for slot, share in enumerate(received):
            ballot[slot] = share
        print_ln('Got voter data.')

        # Validation outcome: 1 good, 2 bad voterid, 3 bad talk id.
        # Kept in an Array so the conditional body below can overwrite it.
        outcome = Array(1, regint)
        outcome[0] = validate_voter_id(ballot[0], voter_ids)

        @if_(outcome[0] == 1)
        def check_talk_ids():
            outcome[0] = validate_talk_ids(ballot, votes_count)

        print_ln('Voter validation returns %s.', outcome[0])
        regint.write_to_socket(voter_socket, outcome)
        # Close socket handled by client

        @if_(outcome[0] == 1)
        def record_vote():
            apply_votes(ballot, voter_history, votes_count)
            calculate_send_result(votes_count)
            # Persist to file to enable restarts.
            writesharestofile(*ballot)

        # Returning True keeps the do_while loop going.
        return True
def validate_voter_id(input, voter_ids):
    """Does the input appear in the list of voter ids.

    Input of zero (or less) is always invalid.
    Returns the revealed result: 1 for a valid voter id, 2 for an invalid
    one. accept_votes treats 1 as good and 2 as a bad voter id.
    """
    # One element Array so the nested bodies below can update the flag.
    is_valid = Array(1, sint)
    is_valid[0] = input > 0

    # Only search the list when the id is positive; this reveals just that bit.
    @if_(is_valid[0].reveal() == 1)
    def chk_in_list():
        is_valid[0] = sint(0)

        @for_range(len(voter_ids))
        def loop_body(i):
            # Accumulate "matches any slot" without revealing which slot.
            is_valid[0] = or_op(is_valid[0], input == voter_ids[i])

    return if_else(is_valid[0], 1, 2).reveal()
def validate_talk_ids(vote_array, votes_count):
    """Check the talk ids of a vote: the first vote must be greater than zero.

    Not currently checking that talk ids exist in votes_count[i][0], so
    votes_count is unused here.
    Returns the revealed result: 1 when valid, 3 when not.
    """
    first_choice_given = vote_array[1] > 0
    outcome = if_else(first_choice_given, 1, 3)
    return outcome.reveal()
def apply_votes(vote_array, voter_history, votes_count):
    """Insert/replace a user's votes both in voter history and votes count.

    vote_array -- 4 shares: voter id followed by up to 3 talk ids.

    There is no prior knowledge of where the voter sits in the history, so
    the row index is found by a full scan and then revealed. Revealing the
    index does not reveal the voter id, only that this is an updating vote.
    """
    # TODO could combine vote_array and voter_history if chg logic to not reveal array index.
    existing_row = Array(1, sint)
    existing_row[0] = sint(-1)
    free_row = Array(1, sint)
    free_row[0] = sint(-1)

    # One pass finds both the row already holding this voter id and the
    # last empty (voter id 0) row.
    @for_range(len(voter_history))
    def scan_history(row):
        same_voter = vote_array[0] == voter_history[row][0]
        existing_row[0] = if_else(same_voter, row, existing_row[0])
        unused = sint(0) == voter_history[row][0]
        free_row[0] = if_else(unused, row, free_row[0])

    replace_at = regint(existing_row[0].reveal())
    insert_at = regint(free_row[0].reveal())

    @if_e(replace_at >= 0)
    def replace_previous_vote():
        # Take the old votes back out of the totals before overwriting them.
        update_summary_all(voter_history[replace_at], votes_count, -1)

        @for_range(1, 4)
        def copy_talk_ids(col):
            voter_history[replace_at][col] = vote_array[col]

        print_ln('Vote history updated.')

    @else_
    def store_first_vote():
        # NOTE(review): if no row is empty insert_at is still -1 here and is
        # used as a row index - confirm the history can never be full.
        @for_range(0, 4)
        def copy_whole_vote(col):
            voter_history[insert_at][col] = vote_array[col]

        print_ln('Vote history inserted.')

    # In both cases the new votes are added to the totals.
    update_summary_all(vote_array, votes_count, 1)
def update_summary_all(vote_array, votes_count, operator):
    """Adjust summary of talk id votes with vote_array.

    Gives 3 points for the first talk id, 2 for the second, 1 for the third.
    operator is the multiplier applied to the points: 1 to add votes,
    -1 to remove them.
    Note 2nd and 3rd talk ids can be 0, in which case they are ignored.
    """
    # Talk id already validated to be non zero, so don't check.
    update_summary(vote_array[1], votes_count, operator * 3)

    # Optional choices: (position in vote_array, points weight).
    for position, weight in ((2, 2), (3, 1)):
        optional_talk = vote_array[position]

        @if_((optional_talk > 0).reveal() == 1)
        def add_optional_choice():
            update_summary(optional_talk, votes_count, operator * weight)
def update_summary(talk_id, votes_count, points):
    """Find the matching talk id in votes_count and add points to its total.

    Note votes_count will have been initialised with all talk ids.
    The matching row index is revealed; when no row matches, a message is
    printed and the totals are left unchanged.
    """
    # TODO Improvement is to use if_else to apply updates on every row - either points or 0.
    found_row = Array(1, sint)
    found_row[0] = sint(-1)

    @for_range(len(votes_count))
    def scan_talks(row):
        same_talk = talk_id == votes_count[row][0]
        found_row[0] = if_else(same_talk, row, found_row[0])

    row_index = regint(found_row[0].reveal())

    @if_e(row_index >= 0)
    def add_points():
        votes_count[row_index][1] += points

    @else_
    def unknown_talk():
        print_ln('Got invalid talk id, not adding to totals.')
def calculate_send_result(votes_count):
    """Given the talk ids and points totals calculate the top 3 in order.

    The result is revealed, printed and sent to the management server as
    6 values: talk id and points for first, second and third place.
    """
    # Layout: [id 1st, points 1st, id 2nd, points 2nd, id 3rd, points 3rd].
    leaders = Array(6, sint)
    leaders.assign_all(sint(0))

    @for_range(len(votes_count))
    def consider_talk(row):
        update_top_three(votes_count[row], leaders)

    revealed = Array(6, regint)

    @for_range(len(revealed))
    def open_value(slot):
        revealed[slot] = leaders[slot].reveal()

    print_ln('Top 3 are: %s %s, %s %s, %s %s',
             *[revealed[slot] for slot in range(6)])
    regint.write_to_socket(MGMT_SOCKET_ID, revealed)
def update_top_three(curr_vote, top_three):
    """Compare curr_vote against the current top votes; shuffle lower entries
    down and insert it if it has a greater number of points.

    curr_vote -- pair [talk id, points].
    top_three -- 6 element array of (talk id, points) pairs, best first.
    A strict greater-than is used, so a tie leaves the existing entry in place.
    """
    # TODO Improvement to avoid revealing order for every vote summary passed in
    def place_at(position):
        """Make room at position by moving pairs down, then store curr_vote."""
        @for_range(3, position - 1, -1)
        def move_down(slot):
            # Working from the bottom up keeps values from being overwritten.
            top_three[slot + 2] = top_three[slot]

        top_three[position] = curr_vote[0]
        top_three[position + 1] = curr_vote[1]

    def outscores(points_slot):
        """Revealed test: does curr_vote have more points than that slot."""
        return (curr_vote[1] > top_three[points_slot]).reveal() == 1

    @if_e(outscores(1))
    def new_first():
        place_at(0)

    @else_
    def below_first():
        @if_e(outscores(3))
        def new_second():
            place_at(2)

        @else_
        def below_second():
            @if_(outscores(5))
            def new_third():
                place_at(4)
def main():
    """Program entry: listen for connections, set up the lookup data from the
    management server and, if its public key validated, reload history and
    process votes."""
    listen(VOTING_PORTNUM)
    print_ln('Listening for client connections on base port %s', VOTING_PORTNUM)

    setup = mgmtserver_init()
    pub_key_valid = setup[0]
    voter_ids, voter_id_count, votes_count, talk_id_count = setup[1:]

    # 8 matching key parts means the management server was authenticated.
    @if_e(pub_key_valid != 8)
    def abort_setup():
        print_ln('Unable to continue, setup failed.')

    @else_
    def run_voting():
        voter_history = load_voter_history(votes_count)
        print_ln('Setup finished got %s voter ids, %s talk ids.',
                 voter_id_count.reveal(), talk_id_count.reveal())
        # Send the totals rebuilt from history before accepting new votes.
        calculate_send_result(votes_count)
        accept_votes(voter_history, votes_count, voter_ids)


main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment