Skip to content

Instantly share code, notes, and snippets.

@risingnote
Created March 31, 2017 08:04
Show Gist options
  • Save risingnote/f5039cfa3140d4551ae53abf5b71ef6b to your computer and use it in GitHub Desktop.
Save risingnote/f5039cfa3140d4551ae53abf5b71ef6b to your computer and use it in GitHub Desktop.
TPMPC SPDZ MPC Voting Program
"""
Receive a share of votes for between 1 and 3 workshop talks, store in the
voter history (voter id, talk id * 3) and votes count (talk id, points total).
The voter history supports a voter recasting their votes.
The votes count supports calculating the top 3 talks which are revealed and
returned to the management server.
At startup the management server (verified by public key) sends a list of shares
of valid voter ids and talk ids. The program is limited to a fixed array size of
voter ids and talk ids, with unused slots set to zero.
Voter shares are stored to disk and reloaded on restart, to provide persistence.
This mpc program is to be used in conjunction with the spdz-voting-gui.
"""
from Compiler.types import sint, regint, Array, Matrix
from Compiler.instructions import listen, acceptclientconnection, readsharesfromfile, writesharestofile
from Compiler.library import print_ln, if_, if_e, else_, for_range, do_while
from Compiler.util import if_else
# `program` is not imported here: it is the Program instance injected by
# compilerLib.run when this file is compiled.
program.set_bit_length(32)

# Port used both for listen() and for accepting client connections.
VOTING_PORTNUM = 14000

# Fixed sizes of the voter id and talk id lookup tables.
MAX_NUM_VOTERIDS = 100
MAX_NUM_TALKIDS = 40

# Public key of management server, providing voter id data.
# Prod
AUTH_PUBLIC_KEY = '1da3dc974b0c04374c3b4797a522f50507b81584af653e08122db36fe178a521'

# Register passed to acceptclientconnection() for the management server
# connection and reused whenever data is exchanged with that server.
MGMT_SOCKET_ID = regint()
def or_op(a, b):
    """Convenience method: logical OR of two 0/1 values built from
    arithmetic only (a + b - a*b), so it can be applied to sints."""
    total = a + b
    return total - a * b
def reformat_public_key(key_hex):
    """Convert the authorised public key from a hex string into a list of
    signed 32 bit ints (8 of them for a 64 character key), ready for later
    validation.

    The string is cut into runs of 8 hex characters. Each run is parsed as an
    unsigned value and anything above 0x7FFFFFFF is wrapped into the negative
    range by subtracting 2**32.

    Returns a real list rather than the result of map(), so callers can index
    into it on Python 3 as well as Python 2.
    """
    def hex_to_signed(hex_int):
        """Parse one hex chunk as a signed 32 bit int."""
        int_val = int(hex_int, 16)
        if int_val > 0x7FFFFFFF:
            int_val -= 0x100000000
        return int_val

    return [hex_to_signed(key_hex[i:i + 8]) for i in range(0, len(key_hex), 8)]
def mgmtserver_init():
    """High level control for management server setup of lookup data.

    Authenticates the management server and, only when the key check equals
    8, reads the voter id and talk id lookup data from it.

    Returns a list:
    [key check result, voter id Array, voter id count,
     talk votes Matrix (talk id, points), talk id count].
    When the key check fails the lookup entries are the empty placeholders.
    """
    key_check = mgmtserver_connection()

    # The branch bodies below are closures and cannot rebind locals of this
    # function, so results are written into a mutable list instead.
    # Slots: voter id Array, voter id count, Matrix[talk id][points],
    # talk id count.
    lookup = [Array(0, sint), sint(0), Matrix(0, 2, sint), sint(0)]

    @if_e(key_check != 8)
    def reject_key():
        print_ln('Management server public key validation failed.')

    @else_
    def load_lookup_data():
        print_ln('Management server public key accepted.')
        voters, voter_total = setup_voter_data(MAX_NUM_VOTERIDS)
        lookup[0] = voters
        lookup[1] = voter_total
        talks, talk_total = setup_talk_data(MAX_NUM_TALKIDS)
        lookup[2] = talks
        lookup[3] = talk_total

    return [key_check, lookup[0], lookup[1], lookup[2], lookup[3]]
def mgmtserver_connection():
    """Wait for socket connection from management server, read client public
    key and authenticate.

    The received key parts are compared one by one against the parts of
    AUTH_PUBLIC_KEY, and the number of matching parts is returned. A return
    of 8 indicates all good.
    """
    # list() guarantees the expected parts can be indexed below, even if the
    # helper hands back a lazy iterator (as map() does on Python 3).
    expected_parts = list(reformat_public_key(AUTH_PUBLIC_KEY))

    acceptclientconnection(MGMT_SOCKET_ID, VOTING_PORTNUM)
    received_parts = regint.read_client_public_key(MGMT_SOCKET_ID)

    validation_count = 0
    for index, keypart in enumerate(received_parts):
        matches = keypart == expected_parts[index]
        validation_count += if_else(matches, 1, 0)
    return validation_count
def setup_voter_data(list_size):
    """Read shares of the voter ids list from the management server.

    Returns (Array of list_size voter id shares, id count). The count is one
    more than the index of the last non zero slot, or 0 if every slot is zero.
    """
    received = sint.receive_from_client(list_size, MGMT_SOCKET_ID)
    voter_ids = Array(list_size, sint)
    used_slots = regint(0)
    for slot in range(list_size):
        slot_in_use = received[slot] != 0
        # Later non zero slots overwrite the running count.
        used_slots = if_else(slot_in_use, slot + 1, used_slots)
        voter_ids[slot] = received[slot]
    return voter_ids, used_slots
def setup_talk_data(list_size):
    """Read shares of the talk ids list from the management server.

    Returns (Matrix of list_size rows holding talk id and points, id count).
    Points start at 0. The count is one more than the index of the last non
    zero slot, or 0 if every slot is zero.
    """
    received = sint.receive_from_client(list_size, MGMT_SOCKET_ID)

    # Column 0 holds the talk id, column 1 the points (initially 0).
    talk_votes = Matrix(list_size, 2, sint)
    talk_votes.assign_all(sint(0))

    used_slots = regint(0)
    for slot in range(list_size):
        slot_in_use = received[slot] != 0
        # Later non zero slots overwrite the running count.
        used_slots = if_else(slot_in_use, slot + 1, used_slots)
        talk_votes[slot][0] = received[slot]
    return talk_votes, used_slots
def load_voter_history(votes_count):
    """Load the shares of voter history from disk at startup.

    Reads the persisted shares four at a time and replays each group through
    apply_votes, which also rebuilds the totals in votes_count.
    Returns the voter history Matrix (MAX_NUM_VOTERIDS rows of 4 shares).
    """
    voter_history = Matrix(MAX_NUM_VOTERIDS, 4, sint)
    voter_history.assign_all(sint(0))

    # Single element array so the read position survives loop iterations.
    file_posn = Array(1, regint)
    file_posn[0] = 0
    print_ln("Loading voter history.")

    @do_while
    def replay_stored_votes():
        next_posn = regint()
        # Read 4 shares at a time.
        stored_shares = [sint() for _ in range(4)]
        readsharesfromfile(file_posn[0], next_posn, *stored_shares)

        # First time the file may not exist, so check for the -2 marker.
        @if_(next_posn != -2)
        def replay_one_vote():
            file_posn[0] = next_posn
            # Process shares as if input by a user to build up the history.
            stored_vote = Array(4, sint)
            for part in range(4):
                stored_vote[part] = stored_shares[part]
            apply_votes(stored_vote, voter_history, votes_count)

        # Keep looping only while the reported end position is positive.
        return next_posn > 0

    return voter_history
def accept_votes(voter_history, votes_count, voter_ids):
    """Listen in a never ending loop for client connections, validate and
    process votes.

    For each connection the validation result is written back to the client.
    Only a vote that validates (result 1) is applied, triggers a fresh top 3
    calculation and is persisted to file.
    """
    @do_while
    def serve_one_voter():
        voter_socket = regint()
        print_ln('Waiting for voter data....')
        acceptclientconnection(voter_socket, VOTING_PORTNUM)
        # The voter's public key is read but not checked.
        regint.read_client_public_key(voter_socket)

        # Expect voterid, talk id, talk id | 0, talk id | 0.
        received = sint.receive_from_client(4, voter_socket)
        vote_array = Array(4, sint)
        for part in range(4):
            vote_array[part] = received[part]
        print_ln('Got voter data.')

        # Validate (1 good, 2 bad voterid, 3 bad talk id)
        result = Array(1, regint)
        result[0] = validate_voter_id(vote_array[0], voter_ids)

        # Talk ids are only examined once the voter id has passed.
        @if_(result[0] == 1)
        def check_talk_ids():
            result[0] = validate_talk_ids(vote_array, votes_count)

        print_ln('Voter validation returns %s.', result[0])
        regint.write_to_socket(voter_socket, result)
        # Close socket handled by client

        @if_(result[0] == 1)
        def record_vote():
            apply_votes(vote_array, voter_history, votes_count)
            calculate_send_result(votes_count)
            # Persist to file to enable restarts.
            writesharestofile(*vote_array)

        # Always continue waiting for the next voter.
        return True
def validate_voter_id(input, voter_ids):
    """Does the input appear in the list of voter ids.

    An input that is not greater than zero is always invalid; whether the
    input is greater than zero is revealed in order to skip the list search.
    Returns a revealed value: 1 for valid, 2 for not valid. This matches the
    caller in accept_votes, which treats 1 as good and 2 as a bad voter id.
    """
    is_valid = Array(1, sint)
    is_valid[0] = input > 0

    @if_(is_valid[0].reveal() == 1)
    def chk_in_list():
        # Start from "not found" and OR in a match against every slot.
        is_valid[0] = sint(0)

        @for_range(len(voter_ids))
        def loop_body(i):
            is_valid[0] = or_op(is_valid[0], input == voter_ids[i])

    return if_else(is_valid[0], 1, 2).reveal()
def validate_talk_ids(vote_array, votes_count):
    """Check the talk ids of a vote: the first talk id must be greater than
    zero.

    Returns a revealed 1 when valid, 3 when not.
    Not currently checking that talk ids exist in votes_count[i][0].
    """
    first_talk_given = vote_array[1] > 0
    outcome = if_else(first_talk_given, 1, 3)
    return outcome.reveal()
def apply_votes(vote_array, voter_history, votes_count):
    """Insert/replace a user's votes both in voter history and votes count.

    The history is scanned for a row holding the same voter id and for an
    empty (zero) row. Both row indexes are revealed. There is no prior
    knowledge about the voter history index position, so this does not
    reveal the voter id, only that this is an updating vote.

    On a match the points of the old vote are removed and the row's talk ids
    are replaced; otherwise the whole vote is written into the empty row.
    In both cases the points of the new vote are then added.
    """
    # TODO could combine vote_array and voter_history if chg logic to not reveal array index.
    matched_row = Array(1, sint)
    matched_row[0] = sint(-1)
    free_row = Array(1, sint)
    free_row[0] = sint(-1)

    # One pass finds both the matching voter row and the last empty (0) row;
    # -1 is left in place when nothing is found.
    @for_range(len(voter_history))
    def scan_history(row):
        same_voter = vote_array[0] == voter_history[row][0]
        matched_row[0] = if_else(same_voter, row, matched_row[0])
        unused = sint(0) == voter_history[row][0]
        free_row[0] = if_else(unused, row, free_row[0])

    replace_index = regint(matched_row[0].reveal())
    insert_index = regint(free_row[0].reveal())

    @if_e(replace_index >= 0)
    def replace_existing_vote():
        # Take the previously awarded points away before overwriting.
        update_summary_all(voter_history[replace_index], votes_count, -1)

        # Column 0 (voter id) already matches, only the talk ids change.
        @for_range(1, 4)
        def copy_talk_ids(col):
            voter_history[replace_index][col] = vote_array[col]

        print_ln('Vote history updated.')

    @else_
    def insert_new_vote():
        @for_range(0, 4)
        def copy_whole_vote(col):
            voter_history[insert_index][col] = vote_array[col]

        print_ln('Vote history inserted.')

    update_summary_all(vote_array, votes_count, 1)
def update_summary_all(vote_array, votes_count, operator):
    """Adjust the summary of talk id votes with vote_array.

    The first talk id is worth 3 points, the second 2 and the third 1, each
    multiplied by operator (1 to add votes, -1 to remove them).
    The 2nd and 3rd talk ids can be 0, in which case they are ignored; whether
    they are greater than zero is revealed.
    """
    def adjust_if_given(slot, weight):
        """Apply the weighted points for an optional talk id slot."""
        @if_((vote_array[slot] > 0).reveal() == 1)
        def apply_points():
            update_summary(vote_array[slot], votes_count, operator * weight)

    # First talk id was already validated to be non zero, so no check here.
    update_summary(vote_array[1], votes_count, operator * 3)
    adjust_if_given(2, 2)
    adjust_if_given(3, 1)
def update_summary(talk_id, votes_count, points):
    """Find the row of votes_count matching talk_id and add points to it.

    The matching row index is revealed. If no row matches, a message is
    printed and the totals are left unchanged.
    Note votes_count will have been initialised with all talk ids.
    """
    # TODO Improvement is to use if_else to apply updates on every row - either points or 0.
    found_row = Array(1, sint)
    found_row[0] = sint(-1)

    # -1 stays in place when no talk id matches.
    @for_range(len(votes_count))
    def scan_talks(row):
        same_talk = talk_id == votes_count[row][0]
        found_row[0] = if_else(same_talk, row, found_row[0])

    row_index = regint(found_row[0].reveal())

    @if_e(row_index >= 0)
    def add_points():
        votes_count[row_index][1] += points

    @else_
    def report_unknown_talk():
        print_ln('Got invalid talk id, not adding to totals.')
def calculate_send_result(votes_count):
    """Given the talk ids and points totals, calculate the top 3 in order.

    The six resulting values (talk id, points for each of the three places)
    are revealed, printed and sent to the management server.
    """
    # Layout: [id 1st, points 1st, id 2nd, points 2nd, id 3rd, points 3rd].
    podium = Array(6, sint)
    podium.assign_all(sint(0))

    @for_range(len(votes_count))
    def rank_talk(row):
        update_top_three(votes_count[row], podium)

    results = Array(6, regint)

    @for_range(len(results))
    def reveal_entry(k):
        results[k] = podium[k].reveal()

    revealed = [results[k] for k in range(6)]
    print_ln('Top 3 are: %s %s, %s %s, %s %s', *revealed)
    regint.write_to_socket(MGMT_SOCKET_ID, results)
def update_top_three(curr_vote, top_three):
    """Compare curr_vote (talk id, points) against the current top votes and,
    if it has strictly more points than one of them, shuffle the lower
    entries down and insert it.

    top_three holds three (talk id, points) pairs, best first. The outcome of
    each points comparison is revealed.
    """
    # TODO Improvement to avoid revealing order for every vote summary passed in
    def place_at(position):
        """Make room at position by moving pairs down, then store the vote."""
        # Walk from index 3 down to position, moving each value two slots on.
        @for_range(3, position - 1, -1)
        def move_down(k):
            top_three[k + 2] = top_three[k]

        top_three[position] = curr_vote[0]
        top_three[position + 1] = curr_vote[1]

    def beats(points_slot):
        """Revealed test: does the vote have more points than that slot."""
        return (curr_vote[1] > top_three[points_slot]).reveal() == 1

    @if_e(beats(1))
    def becomes_first():
        place_at(0)

    @else_
    def below_first():
        @if_e(beats(3))
        def becomes_second():
            place_at(2)

        @else_
        def below_second():
            @if_(beats(5))
            def becomes_third():
                place_at(4)
def main():
    """Program entry: listen for connections, set up lookup data from the
    management server and, if its public key validated, reload the voter
    history, send the initial result and start accepting votes."""
    listen(VOTING_PORTNUM)
    print_ln('Listening for client connections on base port %s', VOTING_PORTNUM)

    setup = mgmtserver_init()
    pub_key_valid = setup[0]
    voter_ids, voter_id_count = setup[1], setup[2]
    votes_count, talk_id_count = setup[3], setup[4]

    @if_e(pub_key_valid != 8)
    def abort_setup():
        print_ln('Unable to continue, setup failed.')

    @else_
    def run_voting():
        voter_history = load_voter_history(votes_count)
        print_ln('Setup finished got %s voter ids, %s talk ids.', voter_id_count.reveal(),
                 talk_id_count.reveal())
        # Send the totals rebuilt from the stored history before new votes.
        calculate_send_result(votes_count)
        accept_votes(voter_history, votes_count, voter_ids)


main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment