Skip to content

Instantly share code, notes, and snippets.

@wolfspider
Created November 29, 2024 05:10
Show Gist options
  • Save wolfspider/24811c15179706c72726f2af88441473 to your computer and use it in GitHub Desktop.
Save wolfspider/24811c15179706c72726f2af88441473 to your computer and use it in GitHub Desktop.
Formal RingBuffer 2: Synchronicity
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
/* Ring buffer of int32_t values.
 * `first` and `length` are pointers, so a by-value copy of this struct
 * (as taken by pop__int32_t) still updates the caller's indices. */
typedef struct t__int32_t_s
{
int32_t *b;            /* backing storage with total_length slots */
uint32_t *first;       /* index of the front element (the one pop__int32_t returns) */
uint32_t *length;      /* number of occupied slots */
uint32_t total_length; /* capacity of b */
bool lock;             /* if true when push_end__int32_t first sees *first == 0, that wrap skips the
                          buffer dump and clears the flag */
}
t__int32_t;
/* Successor of slot index i in a ring of total_length slots:
 * i + 1, or 0 when i is the last slot. */
uint32_t next(uint32_t i, uint32_t total_length)
{
  uint32_t candidate = i + 1U;
  return (candidate == total_length) ? 0U : candidate;
}
/* Predecessor of slot index i in a ring of total_length slots:
 * i - 1, or the last slot (total_length - 1) when i is 0. */
uint32_t prev(uint32_t i, uint32_t total_length)
{
  return (i == 0U) ? total_length - 1U : i - 1U;
}
/* Slot index just after the last of `length` elements starting at slot i,
 * i.e. (i + length) mod total_length, computed without overflowing i + length.
 * When the ring is full this is i itself. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
  uint32_t room = total_length - length; /* free slots */
  if (room == 0U)
    return i;
  if (i < room)
    return i + length; /* no wrap-around */
  return i - room;     /* equals i + length - total_length */
}
/* Prepend e: store it in the slot before the current front, make that slot
 * the new front, and grow length by one. No capacity check is done here;
 * push_cont__int32_t only calls this while *length < total_length. */
void push__int32_t(t__int32_t *x, int32_t e)
{
  uint32_t *head = x->first;
  uint32_t slot = prev(*head, x->total_length);
  x->b[slot] = e;
  *head = slot;
  ++*x->length;
}
/* Overwriting push for a full ring: write e at one_past_last (which is
 * *first when the ring is full) and advance *first; *length is unchanged.
 * Whenever *first is 0 on entry, either consume the lock (first time, if
 * set) or dump the whole storage before it is overwritten. */
void push_end__int32_t(t__int32_t *x, int32_t e)
{
  if (*x->first == 0)
  {
    if (!x->lock)
    {
      printf("Buffer full. Current state before overwriting:\n");
      for (uint32_t k = 0; k < x->total_length; ++k)
        printf("%d ", x->b[k]);
      printf("\n");
    }
    else
    {
      printf("Lock acquired overwriting init...\n");
      x->lock = false;
    }
  }
  x->b[one_past_last(*x->first, *x->length, x->total_length)] = e;
  *x->first = next(*x->first, x->total_length);
}
/* Push that never fails: prepend while there is room, otherwise overwrite
 * via push_end__int32_t. */
void push_cont__int32_t(t__int32_t *x, int32_t e)
{
  if (*x->length >= x->total_length)
  {
    push_end__int32_t(x, e);
    return;
  }
  push__int32_t(x, e);
}
/* Remove and return the front element, advancing *first and shrinking
 * *length. The struct is passed by value but first/length are pointers, so
 * the caller's indices are updated.
 * Precondition (not checked): *x.length > 0, otherwise *length wraps around. */
int32_t pop__int32_t(t__int32_t x)
{
  uint32_t head = *x.first;
  int32_t value = x.b[head];
  *x.first = next(head, x.total_length);
  --*x.length;
  return value;
}
int32_t main(void)
{
int32_t b[3U];
for (uint32_t _i = 0U; _i < 3U; ++_i)
b[_i] = (int32_t)1;
uint32_t buf0 = 0U;
uint32_t buf = 0U;
t__int32_t rb = { .b = b, .first = &buf0, .length = &buf, .total_length = 3U, .lock = true };
for (uint32_t _i = 0U; _i < rb.total_length; ++_i)
push_cont__int32_t(&rb, (int32_t)0);
for (uint32_t _i = 0U; _i < 14U; ++_i)
push_cont__int32_t(&rb, (int32_t)_i * 10 + 1);
printf("Buffer pushed 14 times. Current state:\n");
for (uint32_t i = 0; i < rb.total_length; ++i)
{
printf("%d ", rb.b[i]);
}
printf("\n");
printf("Pop back through elements:\n");
int32_t r = pop__int32_t(rb);
printf("pop: %d\n", r);
int32_t r1 = pop__int32_t(rb);
printf("pop: %d\n", r1);
int32_t r2 = pop__int32_t(rb);
printf("pop: %d\n", r2);
int32_t r3 = pop__int32_t(rb);
printf("pop: %d\n", r3);
return r3;
}
@wolfspider
Copy link
Author

wolfspider commented Dec 1, 2024

Because I can't quite back down from a challenge like this, I decided to go ahead and prove the forward motion of overwriting the ring with a much more thoroughly proven lemma. By modifying the push_back lemma, I came up with this:

// Specification-only (Ghost) counterpart of `next`: the successor of slot
// index `i` in a ring of `total_length` slots, wrapping to 0 after the last
// slot. The postcondition spells out the exact result, so it can be used
// directly in other specifications (push_back's ensures clause states
// `deref h1 x.first = next_ghost (deref h0 x.first) x.total_length`).
let next_ghost (i total_length: U32.t): Ghost U32.t
  (requires U32.(total_length >^ 0ul /\ i <^ total_length))
  (ensures fun r -> U32.(r <^ total_length /\ (if i =^ total_length -^ 1ul then r =^ 0ul else r =^ i +^ 1ul)))
=
  // Last slot wraps to 0; any other index just increments.
  if U32.(i =^ total_length -^ 1ul) then
    0ul
  else
    U32.(i +^ 1ul)

// Index of the slot just after the last of `length` elements starting at
// slot `i`, i.e. (i + length) mod total_length. When the ring is full this
// is `i` itself. The result is proven to be a valid slot index.
let one_past_last (i length total_length: U32.t): Pure U32.t
  (requires U32.(total_length >^ 0ul /\ i <^ total_length /\ length <=^ total_length))
  (ensures fun r -> U32.( r <^ total_length ))
=
  let open U32 in
  // Full ring: one past the last element is the first element.
  if length = total_length then
    i
  // i + length >= total_length
  else if i >=^ total_length -^ length then
    // i + length - total_length, carefully crafted to avoid overflow
    length -^ (total_length -^ i)
  else
    i +^ length

// Lemma: if the slot one past the last element already holds `e`, then
// viewing the buffer with one more element gives the old list with `e`
// appended. Proven by induction on `length` (see the decreases clause).
let rec as_list_append_one (#a: eqtype)
  (b: S.seq a)
  (e: a)
  (first length total_length: U32.t): Lemma
  (requires
    well_formed_f b first length total_length /\
    U32.(length <^ total_length) /\
    S.index b (U32.v (one_past_last first length total_length)) = e)
  (ensures
    as_list_f b first U32.(length +^ 1ul) total_length =
    L.append (as_list_f b first length total_length) [ e ])
  (decreases (U32.v length))
=
  // Base case: an empty view extended by one is just [e].
  if U32.(length =^ 0ul) then
    ()
  // Inductive step: drop the head element and recurse on the remaining
  // length - 1 elements starting at the next slot.
  else
    as_list_append_one b e (next first total_length) U32.(length -^ 1ul) total_length

// Overwriting push_back: store `e` in the slot one past the last element,
// then advance `first` by one slot. The postcondition states that `length`
// (and the remaining space) are unchanged, via the explicit `+^ 0ul` /
// `-^ 0ul` terms.
// NOTE(review): `well_formed`, `space_left`, `remaining_space`, `used_slot`
// and `seq_update_unused_preserves_list` are defined elsewhere; their exact
// meaning is not visible here.
let push_back (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires (fun h ->
    well_formed h x /\ space_left h x))
  (ensures (fun h0 r h1 ->
    M.(modifies (loc_union
    (loc_buffer x.length)
    (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
    well_formed h1 x /\
    U32.(remaining_space h1 x =^ remaining_space h0 x -^ 0ul) /\
    deref h1 x.first = next_ghost (deref h0 x.first) x.total_length /\
    // Per the accompanying notes, this is left out because the call to
    // as_list_append_one below already establishes the append relation
    // before `first` is moved.
    //as_list h1 x = L.append (as_list h0 x) [e] /\
    deref h1 x.length = U32.((deref h0 x.length) +^ 0ul)
    ))
=
  let h0 = ST.get () in
  let dest_slot = one_past_last !*x.first !*x.length x.total_length in
  // The target slot is not one of the currently used slots, so writing it
  // leaves the existing list view intact (the next lemma call).
  assert (~ (used_slot h0 x dest_slot));
  x.b.(dest_slot) <- e;
  seq_update_unused_preserves_list (B.as_seq h0 x.b) dest_slot e
    (deref h0 x.first) (deref h0 x.length) x.total_length;
  let h1 = ST.get () in
  // After the write, the slot one past the last holds `e`, which is
  // exactly the precondition of as_list_append_one.
  as_list_append_one (B.as_seq h1 x.b) e
    (deref h1 x.first) (deref h1 x.length) x.total_length;
  x.first *= next !*x.first x.total_length;
  // Identity store on `length`; the accompanying notes say pop did not
  // verify without stating the unchanged length this explicitly.
  x.length *= U32.(!*x.length +^ 0ul)

After a lot of trial and error with the postconditions, I found that I had to add 0 to the length explicitly; otherwise pop did not verify. The verifier was trying to keep me from doing something wrong, and this feels more like a hack. I also explicitly added a ghost version of next (next_ghost) to pin its behavior down even further in the specification. The resulting function is:

/* Overwriting push_back extracted from the F* version: write e in the slot
 * one past the last element, then advance *first by one slot. *length is
 * left unchanged. */
void push_back__int32_t(t__int32_t x, int32_t e)
{
  x.b[one_past_last(*x.first, *x.length, x.total_length)] = e;
  *x.first = next(*x.first, x.total_length);
  /* Identity store kept on purpose: it mirrors `x.length *= U32.(!*x.length +^ 0ul)`
     in the F* source. */
  *x.length = *x.length + 0U;
}

This is the mirror image of what we came up with before; the earlier version simply lacked the last line, which adds 0 to the length. Normally, push_back would fill the buffer and then stop once the length was exceeded; that version uses *x.length = *x.length + 1U; and has no call to next. Explicitly adding 0 certainly keeps the proof more grounded. The moral of the story is that reaching for new levels of correctness can become overly verbose. I am mostly satisfied that the solution we originally came up with for overwriting the ring is not too flimsy, because when we prove it step by step it still holds up. The postcondition as_list h1 x = L.append (as_list h0 x) [e] is commented out because I included the appropriate code (the as_list_append_one call) to show that it actually holds as a postcondition of the append before first is moved (i.e. L.append (as_list_f b first length total_length) [ e ]), so that clause is somewhat redundant.

@wolfspider
Copy link
Author

wolfspider commented Dec 1, 2024

After some more refactoring and error handling, we have the final final FINAL Python implementation:

import struct
import time
import select
import argparse
import netmap
import random


def build_packet() -> bytes:
    """Return the 60-byte template frame sent by the generator.

    Layout (network byte order): broadcast destination MAC, all-zero source
    MAC, EtherType 0x0800 (IPv4), then 46 zero bytes of payload.
    """
    dst_mac = b"\xff" * 6
    src_mac = b"\x00" * 6
    ethertype = 0x0800
    payload = b"\x00" * 46
    return struct.pack("!6s6sH46s", dst_mac, src_mac, ethertype, payload)


class RingBuffer:
    """Overwriting ring buffer that writes packets straight into netmap TX ring 0.

    ``first`` and ``length`` track the logical contents over the ring's slot
    indices. While the ring is not full, ``push`` prepends. Once it is full,
    ``push_end`` overwrites slot by slot. Each time ``first`` comes back to
    slot 0 (except the first time), it reserves a batch of slots, publishes
    them via ``transmit`` and calls ``sync``.
    """

    __slots__ = (
        "txr",
        "num_slots",
        "nm",
        "cur",
        "tail",
        "head",
        "cnt",
        "length",
        "first",
        "batch",
        "pkt_cnt",
        "lock",
    )

    def __init__(self, nm: netmap.Netmap):
        """Attach to ``nm``'s first transmit ring and zero all counters."""
        ring = nm.transmit_rings[0]
        self.txr = ring
        self.num_slots = ring.num_slots
        self.nm = nm
        # Snapshot of the ring pointers at construction time; cur is then
        # advanced locally by space_left() and written back by transmit().
        self.cur, self.tail, self.head = ring.cur, ring.tail, ring.head
        self.cnt = 0  # not read by any method in this class
        self.length = 0
        self.first = 0
        self.batch = 256  # upper bound on slots reserved per space_left() call
        self.pkt_cnt = 0  # running sum of space_left() results
        self.lock = False  # False until push_end() first sees first == 0

    def init(self, pkt: bytes) -> None:
        """Fill every slot of the ring with ``pkt`` using ``push_cont``."""
        view = memoryview(pkt)
        # Each push_cont here goes through push(), which grows length by 1.
        while self.length < self.num_slots:
            self.push_cont(view)

    def tx(self, payload: list[bytes]) -> None:
        """Push every packet in ``payload`` through ``push_cont``.

        Pushes past a full ring go to ``push_end``, which does the actual
        transmission.

        Raises:
            RuntimeError: if the ring reports zero slots (``nm`` is closed first).
        """
        if self.txr.num_slots <= 0:
            self.nm.close()
            raise RuntimeError("number of slots 0!")
        for pkt in payload:
            self.push_cont(pkt)

    def next(self, i):
        """Return the slot index after ``i``, wrapping to 0 after the last slot."""
        return 0 if i == self.num_slots - 1 else i + 1

    def prev(self, i):
        """Return the slot index before ``i``, wrapping to the last slot from 0."""
        return i - 1 if i > 0 else self.num_slots - 1

    def one_past_last(self):
        """Return the slot just after the last element: (first + length) mod num_slots."""
        size, start, used = self.num_slots, self.first, self.length
        if used == size:
            return start
        free = size - used
        if start >= free:
            return start - free  # wrapped: start + used - size
        return start + used

    def space_left(self) -> int:
        """Reserve up to ``batch`` slots by advancing ``self.cur``; return how many.

        The gap from ``cur`` forward to ``tail`` (mod ``num_slots``) is
        computed, and ``num_slots - gap`` slots (capped at ``batch``) are
        reserved.
        NOTE(review): ``self.tail`` is only assigned in ``__init__`` and is
        never refreshed from ``self.txr.tail`` after ``sync()``, so the gap is
        measured against the construction-time tail. Confirm this is intended.
        """
        size = self.num_slots
        gap = self.tail - self.cur
        if gap < 0:
            gap += size
        reserve = min(size - gap, self.batch)

        new_cur = self.cur + reserve
        if new_cur >= size:
            new_cur -= size
        self.cur = new_cur

        return reserve

    def transmit(self) -> None:
        """Publish the reserved slots by setting the ring's cur and head to ``self.cur``."""
        position = self.cur
        self.txr.cur = position
        self.txr.head = position

    def push(self, e):
        """Prepend ``e`` in the slot before ``first`` and make it the new front."""
        dest = self.prev(self.first)
        self.txr.slots[dest].buf[: len(e)] = e
        self.txr.slots[dest].len = len(e)
        self.first = dest
        self.length = min(self.length + 1, self.num_slots)

    def push_end(self, e):
        """Overwrite the slot one past the last element with ``e`` and advance ``first``.

        Sends packets when the buffer is full. Whenever ``first`` is back at
        slot 0, the first occurrence only sets ``lock``. Every later
        occurrence reserves slots (``space_left``), publishes them
        (``transmit``), syncs the ring (``sync``), and adds the reserved
        count to ``pkt_cnt``.
        """
        if self.first == 0:
            if self.lock:
                reserved = self.space_left()
                self.transmit()
                self.sync()
                self.pkt_cnt += reserved
            else:
                self.lock = True
        dest = self.one_past_last()
        self.txr.slots[dest].buf[: len(e)] = e
        self.txr.slots[dest].len = len(e)
        self.first = self.next(self.first)

    def push_cont(self, e):
        """Push ``e``: prepend while there is room, otherwise overwrite via ``push_end``."""
        if self.length >= self.num_slots:
            self.push_end(e)
            return
        self.push(e)

    def pop(self):
        """Remove the front packet and return a copy of its bytes.

        Raises:
            IndexError: if the ring has no slots, or the buffer is empty.
        """
        if self.txr.num_slots <= 0:
            raise IndexError("No remaining slots")
        if self.length == 0:
            raise IndexError("Pop from empty buffer")
        front = self.txr.slots[self.first]
        pkt = bytes(front.buf[: front.len])
        self.first = self.next(self.first)
        self.length -= 1
        return pkt

    def sync(self) -> None:
        """Ask netmap to process the transmit ring (``nm.txsync()``)."""
        self.nm.txsync()


def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open and register a netmap port on ``interface``.

    Returns:
        ``(nm, fd)``, where ``fd`` is ``nm.getfd()`` (the descriptor ``main``
        registers with ``select.poll``).

    Raises:
        RuntimeError: if opening, naming, or registering the port fails.
            ``nm`` is closed first and the original exception is chained
            as ``__cause__``.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Allow interface to initialize
        time.sleep(0.1)  # Reduced from 1s to 0.1s as that should be sufficient
        return nm, nm.getfd()
    except Exception as e:
        nm.close()
        # Explicit chaining keeps the real cause in the traceback.
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e


def main():
    """Command-line entry point: open a netmap port and transmit until stopped.

    The ring is pre-filled with one template frame. Then, on every POLLOUT
    wakeup, a random-sized batch (1..``batch``) of that frame is pushed.
    Ctrl-C, a 2 ms poll timeout, or any other exception ends the run. The
    ``finally`` block reports packet count and average rate and closes the
    port.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    args = parser.parse_args()

    pkt = build_packet()
    print(f"Opening interface {args.interface}")

    # Bound up front so the handlers below can test for None instead of
    # probing locals().
    nm = None
    ring_buffer = None
    t_start = None

    try:
        nm, nfd = setup_netmap(args.interface)
        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(nm)
        ring_buffer.init(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        # Use an efficient polling mechanism
        poller = select.poll()
        poller.register(nfd, select.POLLOUT)

        t_start = time.monotonic()  # More precise than time.time()

        while True:
            if not poller.poll(2):
                print("Timeout occurred")
                break

            pkt_view = memoryview(pkt)
            payload = [pkt_view for _ in range(random.randint(1, ring_buffer.batch))]
            ring_buffer.tx(payload)

    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
        if ring_buffer is not None:
            print("\nCurrent Packet:")
            print(f"\n{ring_buffer.pop()}")
    except Exception as e:
        print(f"\nError during transmission: {e}")
        if ring_buffer is not None:
            print("\nCurrent Packet:")
            print(f"\n{ring_buffer.pop()}")
    finally:
        t_end = time.monotonic()
        if t_start is None:
            print("t_start is not defined. Setting it to 0.")
            duration = 0.01
        else:
            duration = t_end - t_start

        # RingBuffer.pkt_cnt only sums the slots reserved by space_left(). The
        # x10 factor was tuned by hand to roughly match the receiver-side
        # count. Apply it once, so the printed count and the rate (in both
        # the K and M branches) agree.
        sent = ring_buffer.pkt_cnt * 10 if ring_buffer is not None else 0
        rate = sent / (duration * 1000) if duration > 0 else 0.0  # Kpps

        unit = "K"
        if rate > 1000:
            rate /= 1000
            unit = "M"

        print(f"\nPackets sent: [{sent:,}], Duration: {duration:.2f}s")
        print(f"Average rate: [{rate:,.3f}] {unit}pps")
        # NOTE(review): RingBuffer.tx() closes nm itself before raising
        # RuntimeError, so that path closes it twice. Confirm that
        # Netmap.close() tolerates this.
        if nm is not None:
            nm.close()


# Run the packet generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()

And now we make use of pop for debugging purposes:

python git:(master) ✗ sudo python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user

Current Packet:

b'\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'

Packets sent: [42,769,910], Duration: 14.43s
Average rate: [2.964] Mpps

I also found that the packet counts were not accurate, or even close, this whole time! So I scaled them to roughly match what gets measured on the receiving end (which is more accurate):

(pkt-gen -f rx)
139.746366 main_thread [2781] 2.494 Mpps (2.497 Mpkts 1.197 Gbps in 1001039 usec) 465.29 avg_batch 1 min_space
140.747004 main_thread [2781] 2.453 Mpps (2.455 Mpkts 1.177 Gbps in 1000639 usec) 461.75 avg_batch 1 min_space
141.748006 main_thread [2781] 2.413 Mpps (2.416 Mpkts 1.158 Gbps in 1001002 usec) 460.23 avg_batch 1 min_space
142.749057 main_thread [2781] 2.433 Mpps (2.435 Mpkts 1.168 Gbps in 1001051 usec) 460.67 avg_batch 1 min_space
143.750101 main_thread [2781] 2.497 Mpps (2.500 Mpkts 1.199 Gbps in 1001043 usec) 466.03 avg_batch 1 min_space
...
Received 35803118 packets 2148187080 bytes 78392 events 60 bytes each in 21.43 seconds.

Well, isn't that a nice surprise? As it turns out, this is actually doing a little over 1 Gbps.

@wolfspider
Copy link
Author

For the next section, it's back to the drawing board to re-implement this for more speed and accuracy. We've got a rough sketch so far of what we want, and this will help with the low-level stuff. We have also found some interesting techniques to keep us out of the danger zone in C. Managing pointers will become all too relevant soon enough.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment