-
-
Save wolfspider/24811c15179706c72726f2af88441473 to your computer and use it in GitHub Desktop.
#include <stdio.h> | |
#include <stdint.h> | |
#include <stdbool.h> | |
/* Ring buffer of int32_t values.
 *   b            - backing storage holding total_length slots
 *   first        - points at the index of the window's first element
 *   length       - points at the number of elements in the window
 *   total_length - capacity of b
 *   lock         - one-shot flag read by push_end__int32_t: while it is
 *                  true, the first wrap of *first to 0 prints a notice and
 *                  clears it instead of dumping the buffer.
 * Because first/length are pointers, by-value copies of this struct (as
 * taken by pop__int32_t) still update the caller's indices. */
typedef struct t__int32_t_s {
  int32_t *b;
  uint32_t *first;
  uint32_t *length;
  uint32_t total_length;
  bool lock;
} t__int32_t;
/* Successor of index i in a ring of total_length slots: the last index
 * (total_length - 1) wraps back to 0, every other index moves up by one. */
uint32_t next(uint32_t i, uint32_t total_length)
{
  const uint32_t last = total_length - 1U;
  return (i == last) ? 0U : i + 1U;
}
/* Predecessor of index i in a ring of total_length slots: index 0 wraps
 * around to the last index (total_length - 1). */
uint32_t prev(uint32_t i, uint32_t total_length)
{
  return (i == 0U) ? total_length - 1U : i - 1U;
}
/* Index of the slot just after the last element of a window that starts at
 * i and holds `length` elements in a ring of total_length slots.
 * A full window (length == total_length) yields i itself. When
 * i + length would reach or pass the end, the result wraps; it is computed
 * as length - (total_length - i) so no intermediate sum can overflow. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
  if (length == total_length)
    return i;
  return (i >= total_length - length)
           ? length - (total_length - i)
           : i + length;
}
/* Prepend e: store it in the slot before *first, make that slot the new
 * first, and grow the window by one. No capacity check is done here;
 * push_cont__int32_t only calls this while *length < total_length. */
void push__int32_t(t__int32_t *x, int32_t e)
{
  uint32_t *first = x->first;
  uint32_t *length = x->length;
  const uint32_t slot = prev(*first, x->total_length);

  x->b[slot] = e;
  *first = slot;
  ++*length;
}
void push_end__int32_t(t__int32_t *x, int32_t e) | |
{ | |
if(*x->first == 0) { | |
// Print buffer state when it wraps around | |
if(x->lock == true) | |
{ | |
printf("Lock acquired overwriting init...\n"); | |
x->lock = false; | |
} | |
else | |
{ | |
printf("Buffer full. Current state before overwriting:\n"); | |
for (uint32_t i = 0; i < x->total_length; ++i) | |
{ | |
printf("%d ", x->b[i]); | |
} | |
printf("\n"); | |
} | |
} | |
uint32_t o = one_past_last(*x->first, *x->length, x->total_length); | |
x->b[o] = e; | |
*x->first = next(*x->first, x->total_length); | |
} | |
/* Continuous push: prepend while there is room, and once the window is
 * full switch to overwriting in place via push_end__int32_t. */
void push_cont__int32_t(t__int32_t *x, int32_t e)
{
  if (*x->length >= x->total_length) {
    push_end__int32_t(x, e);
    return;
  }
  push__int32_t(x, e);
}
/* Remove and return the element at *first, advancing *first and shrinking
 * *length. x is passed by value, but first/length are pointers, so the
 * caller's indices are updated.
 * NOTE(review): there is no empty check; popping with *length == 0 returns
 * whatever the slot holds and wraps *length to UINT32_MAX. */
int32_t pop__int32_t(t__int32_t x)
{
  const uint32_t slot = *x.first;
  const int32_t value = x.b[slot];

  *x.first = next(slot, x.total_length);
  --*x.length;
  return value;
}
int32_t main(void) | |
{ | |
int32_t b[3U]; | |
for (uint32_t _i = 0U; _i < 3U; ++_i) | |
b[_i] = (int32_t)1; | |
uint32_t buf0 = 0U; | |
uint32_t buf = 0U; | |
t__int32_t rb = { .b = b, .first = &buf0, .length = &buf, .total_length = 3U, .lock = true }; | |
for (uint32_t _i = 0U; _i < rb.total_length; ++_i) | |
push_cont__int32_t(&rb, (int32_t)0); | |
for (uint32_t _i = 0U; _i < 14U; ++_i) | |
push_cont__int32_t(&rb, (int32_t)_i * 10 + 1); | |
printf("Buffer pushed 14 times. Current state:\n"); | |
for (uint32_t i = 0; i < rb.total_length; ++i) | |
{ | |
printf("%d ", rb.b[i]); | |
} | |
printf("\n"); | |
printf("Pop back through elements:\n"); | |
int32_t r = pop__int32_t(rb); | |
printf("pop: %d\n", r); | |
int32_t r1 = pop__int32_t(rb); | |
printf("pop: %d\n", r1); | |
int32_t r2 = pop__int32_t(rb); | |
printf("pop: %d\n", r2); | |
int32_t r3 = pop__int32_t(rb); | |
printf("pop: %d\n", r3); | |
return r3; | |
} |
Because I can't quite back down from a challenge like this, I decided to go ahead and prove the forward motion of overwriting the ring with a much more thoroughly proven lemma. By modifying the push_back lemma, I came up with this:
// Ghost (erased at extraction) successor on ring indices: the index after `i`
// in a ring of `total_length` slots, wrapping `total_length - 1` back to 0.
// The ensures clause states the result explicitly; it is used in
// `push_back`'s postcondition to describe how `first` moves.
// The precondition `i <^ total_length` keeps `i +^ 1ul` from overflowing.
let next_ghost (i total_length: U32.t): Ghost U32.t
(requires U32.(total_length >^ 0ul /\ i <^ total_length))
(ensures fun r -> U32.(r <^ total_length /\ (if i =^ total_length -^ 1ul then r =^ 0ul else r =^ i +^ 1ul)))
=
if U32.(i =^ total_length -^ 1ul) then
0ul
else
U32.(i +^ 1ul)
// Index of the slot just after the last element of the window that starts
// at `i` and holds `length` elements in a ring of `total_length` slots.
// A full window (`length = total_length`) yields `i` itself. The `Pure`
// postcondition guarantees the result is always a valid index.
let one_past_last (i length total_length: U32.t): Pure U32.t
(requires U32.(total_length >^ 0ul /\ i <^ total_length /\ length <=^ total_length))
(ensures fun r -> U32.( r <^ total_length ))
=
let open U32 in
if length = total_length then
i
// i + length >= total_length
else if i >=^ total_length -^ length then
// i + length - total_length, carefully crafted to avoid overflow
length -^ (total_length -^ i)
// no wrap-around: i + length < total_length
else
i +^ length
// Lemma: if the slot one past the end of a well-formed window already holds
// `e`, then the list view of the window extended by one element equals the
// old list view with `e` appended.
// Proved by induction on `length` (see `decreases`): each recursive call
// considers the window that starts one slot later (`next first`) and is one
// element shorter.
// NOTE(review): `well_formed_f`, `as_list_f` and `next` are defined elsewhere.
let rec as_list_append_one (#a: eqtype)
(b: S.seq a)
(e: a)
(first length total_length: U32.t): Lemma
(requires
well_formed_f b first length total_length /\
U32.(length <^ total_length) /\
S.index b (U32.v (one_past_last first length total_length)) = e)
(ensures
as_list_f b first U32.(length +^ 1ul) total_length =
L.append (as_list_f b first length total_length) [ e ])
(decreases (U32.v length))
=
// Base case: empty window; presumably discharged by unfolding as_list_f.
if U32.(length =^ 0ul) then
()
// Inductive step: same statement for the shorter window starting at next first.
else
as_list_append_one b e (next first total_length) U32.(length -^ 1ul) total_length
// Overwriting "push_back": stores `e` in the slot one past the current last
// element, then advances `first` by one (stated via `next_ghost` in the
// postcondition) while `length` keeps its value. The live window therefore
// rotates forward over the ring instead of growing.
// NOTE(review): `well_formed`, `space_left`, `remaining_space`, `used_slot`,
// `as_list`, `deref`, `next` and `seq_update_unused_preserves_list` are
// defined elsewhere; their meaning is assumed from their names here.
// NOTE(review): the `-^ 0ul` / `+^ 0ul` terms are deliberate no-ops. They make
// the spec say explicitly that remaining space and length do not change.
let push_back (#a: eqtype) (x: t a) (e: a): Stack unit
(requires (fun h ->
well_formed h x /\ space_left h x))
(ensures (fun h0 r h1 ->
M.(modifies (loc_union
(loc_buffer x.length)
(loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
well_formed h1 x /\
U32.(remaining_space h1 x =^ remaining_space h0 x -^ 0ul) /\
deref h1 x.first = next_ghost (deref h0 x.first) x.total_length /\
//as_list h1 x = L.append (as_list h0 x) [e] /\
deref h1 x.length = U32.((deref h0 x.length) +^ 0ul)
))
=
let h0 = ST.get () in
// Destination: the slot just after the last live element.
let dest_slot = one_past_last !*x.first !*x.length x.total_length in
// The destination is not a live slot. Per its name, the lemma call below is
// used to show that updating an unused slot preserves the list view.
assert (~ (used_slot h0 x dest_slot));
x.b.(dest_slot) <- e;
seq_update_unused_preserves_list (B.as_seq h0 x.b) dest_slot e
(deref h0 x.first) (deref h0 x.length) x.total_length;
let h1 = ST.get () in
// Before `first` moves: the window extended by one slot equals the old list
// with `e` appended (this is what the commented-out postcondition expressed).
as_list_append_one (B.as_seq h1 x.b) e
(deref h1 x.first) (deref h1 x.length) x.total_length;
// Rotate the window: advance `first`; `length` is rewritten with its own value.
x.first *= next !*x.first x.total_length;
x.length *= U32.(!*x.length +^ 0ul)
After a lot of trial and error, I found that the postconditions required me to add 0 to the length explicitly; otherwise pop did not verify. The verifier was trying to keep me from doing something wrong, so this feels more like a hack than a real fix. I also added an explicit ghost version of `next` (`next_ghost`) to verify it even more thoroughly. The resulting function is:
/* Overwrite-mode push_back: store e in the slot one past the last element,
 * then advance *first so the window rotates forward. The length update is
 * a deliberate no-op that mirrors the "+^ 0ul" in the verified spec. */
void push_back__int32_t(t__int32_t x, int32_t e)
{
  x.b[one_past_last(*x.first, *x.length, x.total_length)] = e;
  *x.first = next(*x.first, x.total_length);
  *x.length += 0U;
}
This is the mirror image of what we came up with before, except that the earlier version lacks the last line, which adds 0 to the length. A normal push_back would fill the buffer and stop once the length was exceeded; that corresponds to `*x.length = *x.length + 1U;` with the call to `next` removed. Explicitly adding 0 certainly makes the intent more concrete. The moral of the story is that reaching for new levels of correctness can become overly verbose. Still, I am mostly satisfied that the solution for overwriting the ring we originally came up with is not too flimsy, because it still holds up when we prove it step by step. The postcondition `as_list h1 x = L.append (as_list h0 x) [e]` is commented out. Instead, I included the appropriate code (the `as_list_append_one` call) to show that it is actually a postcondition of the append operation before `first` is moved (i.e. `L.append (as_list_f b first length total_length) [ e ]`), so the clause is somewhat redundant.
After some more refactoring and error handling, we have the final, final, FINAL Python implementation:
import struct
import time
import select
import argparse
import netmap
import random
def build_packet() -> bytes:
    """Return a 60-byte Ethernet II frame used as the transmit template.

    Layout: broadcast destination MAC (ff:ff:ff:ff:ff:ff), all-zero source
    MAC, EtherType 0x0800 (IPv4) and a 46-byte zero payload.  The frame is
    packed afresh on every call.
    """
    layout = "!6s6sH46s"
    dst_mac = b"\xff" * 6
    src_mac = bytes(6)
    ethertype = 0x0800
    payload = bytes(46)
    return struct.pack(layout, dst_mac, src_mac, ethertype, payload)
class RingBuffer:
    """Ring buffer layered over the first netmap TX ring.

    The logical window is tracked by ``first`` (slot index where the window
    starts) and ``length`` (occupied slots, capped at ``num_slots``).
    ``push`` prepends while there is room.  Once the window is full,
    ``push_end`` overwrites slots in place and rotates ``first`` forward,
    handing the ring to the kernel each time ``first`` returns to slot 0.
    """

    __slots__ = (
        "txr",
        "num_slots",
        "nm",
        "cur",
        "tail",
        "head",
        "cnt",
        "length",
        "first",
        "batch",
        "pkt_cnt",
        "lock",
    )

    def __init__(self, nm: netmap.Netmap):
        """Bind to ``nm.transmit_rings[0]`` and snapshot its cursors."""
        self.txr = nm.transmit_rings[0]
        self.num_slots = self.txr.num_slots
        self.nm = nm
        # NOTE(review): cur/tail/head are copied once here.  Only ``cur`` is
        # advanced afterwards (in space_left); ``tail`` is never refreshed
        # from the ring after a sync.  Confirm this is intended.
        self.cur = self.txr.cur
        self.tail = self.txr.tail
        self.head = self.txr.head
        self.cnt = 0  # not read anywhere in this class
        self.length = 0
        self.first = 0
        self.batch = 256  # upper bound on slots reserved per space_left()
        self.pkt_cnt = 0  # running total of slots reported by space_left()
        # False until ``first`` first wraps to slot 0 in push_end.  That wrap
        # only arms the flag and transmits nothing.
        self.lock = False

    def init(self, pkt: bytes) -> None:
        """
        Pre-fill the buffer by repeatedly calling `push_cont`.
        Stops when all slots are filled.
        """
        pkt_view = memoryview(pkt)
        # While length < num_slots, push_cont delegates to push (prepend).
        while self.length < self.num_slots:
            self.push_cont(pkt_view)

    def tx(self, payload: list[bytes]) -> None:
        """Feed every packet in ``payload`` through ``push_cont``.

        Actual transmission happens inside ``push_end`` whenever the window
        wraps back to slot 0.

        Raises:
            RuntimeError: if the ring reports zero slots (``nm`` is closed
                first).
        """
        if self.txr.num_slots > 0:
            for pkt in payload:
                self.push_cont(pkt)
        else:
            self.nm.close()
            raise RuntimeError("number of slots 0!")

    def next(self, i):
        """Get the next index in a circular manner."""
        if i == self.num_slots - 1:
            return 0
        else:
            return i + 1

    def prev(self, i):
        """Get the previous index in a circular manner."""
        if i > 0:
            return i - 1
        else:
            return self.num_slots - 1

    def one_past_last(self):
        """Index of the slot just after the window's last element.

        A full window yields ``first`` itself.  The wrap case is computed as
        ``length - (num_slots - first)`` to mirror the verified C/F* version.
        """
        if self.length == self.num_slots:
            return self.first
        elif self.first >= self.num_slots - self.length:
            return self.length - (self.num_slots - self.first)
        else:
            return self.first + self.length

    def space_left(self) -> int:
        """Reserve up to ``batch`` slots and return how many were reserved.

        ``n`` is the circular distance from ``cur`` to ``tail``.  The method
        reserves ``min(num_slots - n, batch)`` slots and advances ``self.cur``
        by that amount (mod ``num_slots``).  Despite its name it mutates
        state: ``self.cur`` is what ``transmit`` later publishes to the ring.
        """
        if self.tail >= self.cur:
            n = self.tail - self.cur
        else:
            n = self.num_slots - (self.cur - self.tail)
        spcn = min(self.num_slots - n, self.batch)
        # Update self.cur to reflect reserved space
        self.cur += spcn
        if self.cur >= self.num_slots:
            self.cur -= self.num_slots
        return spcn

    def transmit(self) -> None:
        """Publish ``self.cur`` as the ring's new ``cur`` and ``head``."""
        self.txr.cur = self.txr.head = self.cur

    def _write_slot(self, index, e):
        """Copy packet bytes ``e`` into TX slot ``index`` and set its length."""
        self.txr.slots[index].buf[: len(e)] = e
        self.txr.slots[index].len = len(e)

    def push(self, e):
        """Prepend ``e`` into the slot before ``first``.

        ``length`` grows by one but saturates at ``num_slots``.
        """
        dest_slot = self.prev(self.first)
        self._write_slot(dest_slot, e)
        self.first = dest_slot
        self.length = min(self.length + 1, self.num_slots)

    def push_end(self, e):
        """Overwrite the slot after the last element and rotate the window.

        Sends packets when the buffer has been fully rewritten: each time
        ``first`` is at slot 0, the reserved slots (``space_left``) are
        published (``transmit``), the ring is synced, and the reserved count
        is added to ``pkt_cnt``.  The first such wrap only arms ``lock`` and
        sends nothing.  ``length`` is intentionally left unchanged.
        """
        if self.first == 0:
            if not self.lock:
                # First wrap after pre-filling: arm the lock, don't send yet.
                self.lock = True
            else:
                n = self.space_left()
                self.transmit()
                self.sync()
                self.pkt_cnt += n
        self._write_slot(self.one_past_last(), e)
        self.first = self.next(self.first)

    def push_cont(self, e):
        """Push ``e``: prepend while there is room, otherwise overwrite."""
        if self.length < self.num_slots:
            self.push(e)
        else:
            self.push_end(e)

    def pop(self):
        """Remove and return the packet bytes stored at ``first``.

        Raises:
            IndexError: if the buffer is empty or the ring has no slots.
        """
        if self.txr.num_slots > 0:
            if self.length == 0:
                raise IndexError("Pop from empty buffer")
            src_slot = self.txr.slots[self.first]
            pkt = bytes(src_slot.buf[: src_slot.len])
            self.first = self.next(self.first)
            self.length -= 1
            return pkt
        else:
            raise IndexError("No remaining slots")

    def sync(self) -> None:
        """Sync the transmit ring."""
        self.nm.txsync()
def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open and register a netmap port.

    Args:
        interface: netmap port name to register (e.g. ``"vale0:0"``).

    Returns:
        ``(nm, fd)``: the registered ``netmap.Netmap`` object and the file
        descriptor from ``nm.getfd()``, which the caller polls.

    Raises:
        RuntimeError: if any setup step fails.  ``nm`` is closed first and
            the original exception is chained as ``__cause__``.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Allow interface to initialize (0.1 s has proven sufficient).
        time.sleep(0.1)
        return nm, nm.getfd()
    except Exception as e:
        nm.close()
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e
def _print_current_packet(ring_buffer) -> None:
    """Debug aid: print the packet at the ring's read position, if any."""
    if ring_buffer is None:
        return
    print("\nCurrent Packet:")
    try:
        print(f"\n{ring_buffer.pop()}")
    except IndexError as e:
        # Don't let a failed debug pop mask the original exit path.
        print(f"\n<unavailable: {e}>")


def main():
    """Command-line entry point: transmit packets until Ctrl-C or timeout.

    Opens the requested netmap port, pre-fills the TX ring with one template
    frame, then keeps pushing random-sized batches whenever the fd polls
    writable.  On exit it prints the packet at the ring's read position, the
    packet total and the average rate, and closes the port.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    args = parser.parse_args()
    pkt = build_packet()
    print(f"Opening interface {args.interface}")
    # Pre-declared so the handlers and `finally` can test them directly.
    nm = None
    ring_buffer = None
    t_start = None
    try:
        nm, nfd = setup_netmap(args.interface)
        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(nm)
        ring_buffer.init(pkt)
        print("Starting transmission, press Ctrl-C to stop")
        # Use an efficient polling mechanism
        poller = select.poll()
        poller.register(nfd, select.POLLOUT)
        t_start = time.monotonic()  # More precise than time.time()
        pkt_view = memoryview(pkt)  # loop-invariant: one view is reused
        while True:
            if not poller.poll(2):
                print("Timeout occurred")
                break
            payload = [pkt_view] * random.randint(1, ring_buffer.batch)
            ring_buffer.tx(payload)
    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
        _print_current_packet(ring_buffer)
    except Exception as e:
        print(f"\nError during transmission: {e}")
        _print_current_packet(ring_buffer)
    finally:
        t_end = time.monotonic()
        if t_start is None:
            print("t_start is not defined. Setting it to 0.")
            duration = 0.01
        else:
            duration = t_end - t_start
        # pkt_cnt is scaled by 10 to roughly match the receiver's count.
        # The rate is derived from that same figure so the K and M branches
        # agree with the "Packets sent" total.
        pkt_cnt_scale = 10
        sent = ring_buffer.pkt_cnt * pkt_cnt_scale if ring_buffer is not None else 0
        rate = sent / (duration * 1000)  # thousands of packets per second
        unit = "K"
        if rate > 1000:
            rate /= 1000
            unit = "M"
        print(f"\nPackets sent: [{sent:,}], Duration: {duration:.2f}s")
        print(f"Average rate: [{rate:,.3f}] {unit}pps")
        if nm is not None:
            nm.close()


if __name__ == "__main__":
    main()
And now we make use of pop for debugging purposes:
➜ python git:(master) ✗ sudo python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user
Current Packet:
b'\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
Packets sent: [42,769,910], Duration: 14.43s
Average rate: [2.964] Mpps
I also found that the packet counts were not accurate, or even close, this whole time. So I adjusted them to roughly match what gets measured on the receiving end (which is more accurate):
(pkt-gen -f rx)
139.746366 main_thread [2781] 2.494 Mpps (2.497 Mpkts 1.197 Gbps in 1001039 usec) 465.29 avg_batch 1 min_space
140.747004 main_thread [2781] 2.453 Mpps (2.455 Mpkts 1.177 Gbps in 1000639 usec) 461.75 avg_batch 1 min_space
141.748006 main_thread [2781] 2.413 Mpps (2.416 Mpkts 1.158 Gbps in 1001002 usec) 460.23 avg_batch 1 min_space
142.749057 main_thread [2781] 2.433 Mpps (2.435 Mpkts 1.168 Gbps in 1001051 usec) 460.67 avg_batch 1 min_space
143.750101 main_thread [2781] 2.497 Mpps (2.500 Mpkts 1.199 Gbps in 1001043 usec) 466.03 avg_batch 1 min_space
...
Received 35803118 packets 2148187080 bytes 78392 events 60 bytes each in 21.43 seconds.
Well, isn't that a nice surprise? As it turns out, this is actually doing a little over 1 Gbps.
For the next section, it's back to the drawing board to re-implement this for more speed and accuracy. We have a rough sketch of what we want so far, and it will help with the low-level work. We have also found some interesting techniques to keep us out of the danger zone in C. Managing pointers will become all too relevant soon enough.
Aside from making more of a mess with lemmas while trying to reuse some of the vaguer ones in the solver, I moved the check for the number of slots (the `if self.txr.num_slots > 0:` guard) into the `tx` method, which brings performance back to normal:
The rest of the code already checks whether there is enough space, and it makes more sense to do this check at the beginning of any stream or broadcast.