Skip to content

Instantly share code, notes, and snippets.

@wolfspider
Created October 31, 2024 15:07
Show Gist options
  • Save wolfspider/c6486b30e7a74beb87188c61ffc5f9e8 to your computer and use it in GitHub Desktop.
Save wolfspider/c6486b30e7a74beb87188c61ffc5f9e8 to your computer and use it in GitHub Desktop.
Formal Methods Ring Buffers
#include <stdio.h>
#include <stdint.h>
/* Ring buffer of int32_t values.
 * The mutable state (first, length) is held behind pointers, so the struct can
 * be passed by value to push/pop while every copy still updates the same
 * underlying indices. */
typedef struct t__int32_t_s
{
/* Backing storage; indexed with values in [0, total_length). */
int32_t *b;
/* Points to the index of the front element (the slot pop reads from). */
uint32_t *first;
/* Points to the number of elements currently stored. */
uint32_t *length;
/* Capacity of b, in elements. */
uint32_t total_length;
}
t__int32_t;
/* Return the index following i in a ring of total_length slots; the last
 * index (total_length - 1) wraps around to 0. */
uint32_t next(uint32_t i, uint32_t total_length)
{
    uint32_t last_index = total_length - 1U;
    return (i == last_index) ? 0U : i + 1U;
}
/* Return the index preceding i in a ring of total_length slots; index 0 wraps
 * around to the last index (total_length - 1). */
uint32_t prev(uint32_t i, uint32_t total_length)
{
    /* i is unsigned, so "not greater than zero" is exactly "equal to zero". */
    return (i == 0U) ? total_length - 1U : i - 1U;
}
/* Return the index just after the last stored element, for a ring of
 * total_length slots holding length elements starting at index i.
 * For a full ring this is i itself. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
    if (length == total_length)
        return i;

    uint32_t unused_slots = total_length - length;
    if (i < unused_slots)
        return i + length;          /* the elements do not wrap */

    /* The elements wrap past the end of the storage. */
    return length - (total_length - i);
}
// Push e at the front of the ring buffer. When the buffer is already full the
// oldest element (the one at the back) is overwritten and the length stays at
// total_length.
void push__int32_t(t__int32_t x, int32_t e)
{
    // The new front is always the slot just before the current front. In a full
    // buffer one_past_last equals first, so that slot is the back element, i.e.
    // the one pushed longest ago: writing there evicts the oldest value while
    // keeping the remaining elements in order. (Writing at b[first] instead
    // would destroy the most recently pushed element.)
    uint32_t dest_slot = prev(*x.first, x.total_length);
    x.b[dest_slot] = e;
    *x.first = dest_slot;

    // Only grow while there is room; a full buffer keeps its length.
    if (*x.length < x.total_length)
        *x.length = *x.length + 1U;
}
/* Remove and return the front element of the ring buffer.
 * No emptiness check is performed here: the caller must only pop when
 * *x.length is non-zero. */
int32_t pop__int32_t(t__int32_t x)
{
    uint32_t front = *x.first;
    int32_t value = x.b[front];

    *x.first = next(front, x.total_length);
    *x.length -= 1U;
    return value;
}
/* Demo driver: push four values into a three-slot ring buffer, pop one,
 * print it and use it as the process exit status. */
int32_t main(void)
{
    /* Backing storage, pre-filled with 1 in every slot. */
    int32_t storage[3U] = { (int32_t)1, (int32_t)1, (int32_t)1 };
    uint32_t first_index = 0U;
    uint32_t count = 0U;
    t__int32_t rb = { .b = storage, .first = &first_index, .length = &count, .total_length = 3U };

    /* Three pushes fill the buffer; the fourth takes the full-buffer path. */
    static const int32_t values[4U] = { 10, 20, 30, 40 };
    for (uint32_t k = 0U; k < 4U; ++k)
        push__int32_t(rb, values[k]);

    int32_t r = pop__int32_t(rb);
    printf("out: %d\n", r);
    return r;
}
@wolfspider
Copy link
Author

So anyhow, I ran it through the wash again, and now we have something better.

import struct
import time
import select
import argparse
import netmap
from array import array
from typing import Optional


def build_packet() -> bytes:
    """Return the fixed 60-byte Ethernet frame used as the test payload.

    Layout, in network byte order: broadcast destination MAC, all-zero
    source MAC, EtherType 0x0800 (IPv4), then 46 zero bytes of payload.
    """
    dst_mac = b"\xff" * 6
    src_mac = b"\x00" * 6
    ethertype_ipv4 = 0x0800
    payload = b"\x00" * 46
    return struct.pack("!6s6sH46s", dst_mac, src_mac, ethertype_ipv4, payload)


class RingBuffer:
    """Cursor bookkeeping around a netmap transmit ring.

    The ring's ``cur``, ``tail`` and ``head`` indices are copied once at
    construction time; the methods below advance those cached copies and
    write them back to ``txr``.

    All index arithmetic wraps with ``% num_slots`` so that it is correct for
    any positive ring size, not only powers of two.
    """

    __slots__ = (
        "txr",
        "num_slots",
        "cur",
        "tail",
        "head",
        "cnt",
        "batch",
        "_batch_mask",
    )

    def __init__(self, txr, num_slots: int):
        """Cache the ring indices of *txr* and set the default batch size.

        Args:
            txr: Transmit ring object exposing ``cur``, ``tail``, ``head``
                and ``slots``.
            num_slots: Number of slots in the ring.
        """
        self.txr = txr
        self.num_slots = num_slots
        self.cur = txr.cur
        self.tail = txr.tail
        self.head = txr.head
        self.cnt = 0
        self.batch = 256
        # Not read by any method of this class; kept so existing code that
        # inspects the attribute keeps working.
        self._batch_mask = self.batch - 1

    def _wrap(self, index: int) -> int:
        """Map a non-negative *index* back into ``[0, num_slots)``."""
        return index % self.num_slots

    def front_load(self, pkt: bytes) -> None:
        """Copy *pkt* into every slot of the ring and set each slot length."""
        pkt_view = memoryview(pkt)
        pkt_len = len(pkt)

        for slot in self.txr.slots[: self.num_slots]:
            slot.buf[0:pkt_len] = pkt_view
            slot.len = pkt_len

    def space_left(self) -> int:
        """Advance the cached ``cur`` by one batch and return the batch size.

        ``n`` is the forward distance from ``cur`` to ``tail``; the batch is
        ``min(num_slots - n, self.batch)``.

        Returns:
            The number of slots ``cur`` was advanced by.
        """
        # NOTE(review): self.tail is the value cached in __init__ (only pop()
        # changes it); it is not re-read from txr here. In netmap the
        # cur-to-tail distance is normally the free space itself - confirm
        # that ``num_slots - n`` is the intended quantity.
        if self.tail >= self.cur:
            n = self.tail - self.cur
        else:
            n = self.num_slots - (self.cur - self.tail)
        spcn = min(self.num_slots - n, self.batch)

        self.cur = self._wrap(self.cur + spcn)
        return spcn

    def push(self) -> None:
        """Advance ``cur`` by one slot and publish it as the ring's cur/head."""
        self.cur = self._wrap(self.cur + 1)
        self.txr.cur = self.txr.head = self.cur

    def pop(self) -> int:
        """Advance ``tail`` by one slot, write it to the ring, return the old tail."""
        tl = self.tail
        self.tail = self._wrap(self.tail + 1)
        self.txr.tail = self.tail
        return tl

    def sync(self, nm: "netmap.Netmap") -> None:
        """Sync the transmit ring by calling ``nm.txsync()``."""
        nm.txsync()


def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open and register a netmap port on *interface*.

    Args:
        interface: Name assigned to the port's ``if_name`` before registering.

    Returns:
        The registered ``netmap.Netmap`` object and its file descriptor.

    Raises:
        RuntimeError: If any setup step fails. The port is closed first and
            the original exception is attached as ``__cause__``.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Allow interface to initialize
        time.sleep(0.1)  # Reduced from 1s to 0.1s as that should be sufficient
        return nm, nm.getfd()
    except Exception as e:
        nm.close()
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause rather than as an incidental context.
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e


def main():
    """Parse the command line and transmit packets until interrupted.

    The ring is pre-filled once with a fixed packet; the loop then waits for
    the netmap descriptor to become writable, advances the ring cursor and
    syncs. Statistics are printed on exit, whatever the reason for leaving
    the loop.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Batch size for packet transmission (power of 2)",
    )
    args = parser.parse_args()

    batch_size = args.batch_size
    # Zero would never transmit anything and negative values make the
    # power-of-two rounding below meaningless, so reject both up front.
    if batch_size < 1:
        parser.error("batch size must be a positive integer")

    # Ensure batch size is a power of 2 (rounding up when it is not)
    if (batch_size & (batch_size - 1)) != 0:
        batch_size = 1 << (batch_size - 1).bit_length()
        print(f"Adjusting batch size to nearest power of 2: {batch_size}")

    pkt = build_packet()
    print(f"Opening interface {args.interface}")

    # Bind everything the ``finally`` block reads before entering ``try`` so
    # that a failure during setup is reported as such instead of being masked
    # by an UnboundLocalError.
    nm = None
    cnt = 0
    t_start = time.monotonic()

    try:
        nm, nfd = setup_netmap(args.interface)
        txr = nm.transmit_rings[0]
        num_slots = txr.num_slots

        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(txr, num_slots)
        ring_buffer.batch = batch_size
        ring_buffer.front_load(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        poller = select.poll()
        poller.register(nfd, select.POLLOUT)

        # Restart the clock so setup time is not counted in the rate.
        t_start = time.monotonic()

        while True:
            if not poller.poll(2):
                print("Timeout occurred")
                break

            n = ring_buffer.space_left()
            ring_buffer.push()
            ring_buffer.sync(nm)
            cnt += n

    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
    except Exception as e:
        print(f"\nError during transmission: {e}")
    finally:
        duration = time.monotonic() - t_start

        # Rate in thousands of packets per second; a zero-length run reports 0
        # rather than dividing by zero.
        rate = cnt / (duration * 1000) if duration > 0 else 0.0
        unit = "K"
        if rate > 1000:
            rate /= 1000
            unit = "M"

        print(f"\nPackets sent: [{cnt:,}], Duration: {duration:.2f}s")
        print(f"Average rate: [{rate:,.3f}] {unit}pps")

        # setup_netmap closes the port itself when it fails, so there is
        # nothing to close here unless setup succeeded.
        if nm is not None:
            nm.close()


# Run the packet generator only when this file is executed as a script.
if __name__ == "__main__":
    main()

@wolfspider
Copy link
Author

wolfspider commented Nov 28, 2024

This still isn't ring buffer-y enough, and after working with Netmap even more I think we need to review some more examples to get the Python code into better shape. The concern is that in a real setting packets will be generated on the fly, and even though we have our methods defined, the code is still just pushing the pre-filled buffer through. I have a hazy memory of seeing something closer to this in the examples, so we will have to search through them to find something adequate.

@wolfspider
Copy link
Author

wolfspider commented Nov 28, 2024

Alright, after going back through it, we have something formally verified, the speed is back up to where it was before, and packets are being generated ad hoc.

➜  python git:(master) ✗ python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user

Packets sent: [159,631,103], Duration: 4.77s
Average rate: [33.465] Mpps

Comparison with the old code:

➜  python git:(master) ✗ python3.10 tx.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Packets sent: 160381952, Avg rate 34.695 Mpps

This should be the final final product in terms of the Python code:

import struct
import time
import select
import argparse
import netmap


def build_packet() -> bytes:
    """Return the fixed 60-byte Ethernet frame that is transmitted.

    The frame is packed in network byte order as: 6-byte broadcast
    destination MAC, 6-byte all-zero source MAC, 2-byte EtherType 0x0800
    (IPv4) and a 46-byte zero payload.
    """
    frame_fields = (
        b"\xff" * 6,  # Destination MAC
        b"\x00" * 6,  # Source MAC
        0x0800,  # EtherType (IPv4)
        b"\x00" * 46,  # Payload
    )
    return struct.pack("!6s6sH46s", *frame_fields)


class RingBuffer:
    """Ring buffer whose storage is the slot array of a netmap transmit ring.

    Two independent pieces of state are kept:

    * ``first`` / ``length`` describe the logical queue of packets stored in
      ``txr.slots`` (used by ``push``, ``push_end``, ``push_cont``, ``pop``).
    * ``cur`` / ``tail`` / ``head`` are copies of the ring indices taken at
      construction time (used by ``space_left`` and ``transmit``).
    """

    __slots__ = (
        "txr",
        "num_slots",
        "cur",
        "tail",
        "head",
        "cnt",
        "length",
        "first",
        "batch",
    )

    def __init__(self, txr, num_slots: int):
        """Cache the ring indices of *txr* and start with an empty queue.

        Args:
            txr: Transmit ring object exposing ``cur``, ``tail``, ``head``
                and ``slots``.
            num_slots: Number of slots in the ring (the queue capacity).
        """
        self.txr = txr
        self.num_slots = num_slots
        self.cur = txr.cur
        self.tail = txr.tail
        self.head = txr.head
        self.cnt = 0
        self.length = 0
        self.first = 0
        self.batch = 256

    def init(self, pkt: bytes) -> None:
        """Pre-fill the buffer with *pkt* by calling ``push_cont`` until full."""
        pkt_view = memoryview(pkt)

        # Call `push_cont` to fill the buffer until it is full
        while self.length < self.num_slots:
            self.push_cont(pkt_view)

    def next(self, i: int) -> int:
        """Return the index after *i*, wrapping the last slot back to 0."""
        if i == self.num_slots - 1:
            return 0
        return i + 1

    def prev(self, i: int) -> int:
        """Return the index before *i*, wrapping 0 back to the last slot."""
        if i > 0:
            return i - 1
        return self.num_slots - 1

    def one_past_last(self) -> int:
        """Return the index just after the last stored element.

        For a full buffer this is ``first`` itself.
        """
        if self.length == self.num_slots:
            return self.first
        if self.first >= self.num_slots - self.length:
            # The stored elements wrap past the end of the slot array.
            return self.length - (self.num_slots - self.first)
        return self.first + self.length

    def space_left(self) -> int:
        """Advance the cached ``cur`` by one batch and return the batch size.

        ``n`` is the forward distance from ``cur`` to ``tail``; the batch is
        ``min(num_slots - n, self.batch)``.

        Returns:
            The number of slots ``cur`` was advanced by.
        """
        # NOTE(review): self.tail is the value cached in __init__ and is never
        # re-read from txr. In netmap the cur-to-tail distance is normally the
        # free space itself - confirm ``num_slots - n`` is what is intended.
        if self.tail >= self.cur:
            n = self.tail - self.cur
        else:
            n = self.num_slots - (self.cur - self.tail)

        spcn = min(self.num_slots - n, self.batch)

        # spcn never exceeds num_slots, so one subtraction is enough to wrap.
        self.cur += spcn
        if self.cur >= self.num_slots:
            self.cur -= self.num_slots

        return spcn

    def transmit(self) -> None:
        """Publish the cached ``cur`` as both ``cur`` and ``head`` of the ring."""
        self.txr.cur = self.txr.head = self.cur

    def _store(self, slot_index: int, e) -> None:
        """Copy payload *e* into slot *slot_index* and record its length."""
        slot = self.txr.slots[slot_index]
        size = len(e)
        slot.buf[:size] = e
        slot.len = size

    def push(self, e):
        """Push *e* at the front of the buffer.

        When the buffer is already full, the slot before ``first`` holds the
        last element, which is overwritten; the length is capped at
        ``num_slots``.
        """
        dest_slot = self.prev(self.first)
        self._store(dest_slot, e)
        self.first = dest_slot
        self.length = min(self.length + 1, self.num_slots)

    def push_end(self, e):
        """Push *e* at the end of the buffer.

        While there is room the buffer grows by one element. When it is full,
        the slot after the last element is the front slot: it is overwritten
        and ``first`` advances, so the former front element is dropped.
        """
        dest_slot = self.one_past_last()
        self._store(dest_slot, e)
        if self.length < self.num_slots:
            # Room available: the new element extends the queue. Advancing
            # ``first`` here would discard the front element and leave the
            # new one outside the [first, first + length) window.
            self.length += 1
        else:
            self.first = self.next(self.first)

    def push_cont(self, e):
        """Push *e*: at the front while not full, at the end once full."""
        if self.length < self.num_slots:
            self.push(e)
        else:
            self.push_end(e)

    def pop(self):
        """Remove the front element and return its payload as ``bytes``.

        Raises:
            IndexError: If the buffer is empty.
        """
        if self.length == 0:
            raise IndexError("Pop from empty buffer")
        src_slot = self.txr.slots[self.first]
        pkt = bytes(src_slot.buf[: src_slot.len])
        self.first = self.next(self.first)
        self.length -= 1
        return pkt

    def sync(self, nm: "netmap.Netmap") -> None:
        """Sync the transmit ring by calling ``nm.txsync()``."""
        nm.txsync()


def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open and register a netmap port on *interface*.

    Args:
        interface: Name assigned to the port's ``if_name`` before registering.

    Returns:
        The registered ``netmap.Netmap`` object and its file descriptor.

    Raises:
        RuntimeError: If any setup step fails. The port is closed first and
            the original exception is attached as ``__cause__``.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Allow interface to initialize
        time.sleep(0.1)  # Reduced from 1s to 0.1s as that should be sufficient
        return nm, nm.getfd()
    except Exception as e:
        nm.close()
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause rather than as an incidental context.
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e


def main():
    """Parse the command line and transmit packets until interrupted.

    The ring is pre-filled once with a fixed packet; the loop then waits for
    the netmap descriptor to become writable, advances the ring cursor and
    syncs. Statistics are printed on exit, whatever the reason for leaving
    the loop.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    args = parser.parse_args()

    pkt = build_packet()
    print(f"Opening interface {args.interface}")

    # Bind everything the ``finally`` block reads before entering ``try`` so
    # that a failure during setup is reported as such instead of being masked
    # by an UnboundLocalError.
    nm = None
    cnt = 0
    t_start = time.monotonic()

    try:
        nm, nfd = setup_netmap(args.interface)
        txr = nm.transmit_rings[0]
        num_slots = txr.num_slots

        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(txr, num_slots)
        ring_buffer.init(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        poller = select.poll()
        poller.register(nfd, select.POLLOUT)

        # Restart the clock so setup time is not counted in the rate.
        t_start = time.monotonic()

        while True:
            if not poller.poll(2):
                print("Timeout occurred")
                break

            n = ring_buffer.space_left()
            ring_buffer.transmit()
            ring_buffer.sync(nm)
            cnt += n

    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
    except Exception as e:
        print(f"\nError during transmission: {e}")
    finally:
        duration = time.monotonic() - t_start

        # Rate in thousands of packets per second; a zero-length run reports 0
        # rather than dividing by zero.
        rate = cnt / (duration * 1000) if duration > 0 else 0.0
        unit = "K"
        if rate > 1000:
            rate /= 1000
            unit = "M"

        print(f"\nPackets sent: [{cnt:,}], Duration: {duration:.2f}s")
        print(f"Average rate: [{rate:,.3f}] {unit}pps")

        # setup_netmap closes the port itself when it fails, so there is
        # nothing to close here unless setup succeeded.
        if nm is not None:
            nm.close()


# Run the packet generator only when this file is executed as a script.
if __name__ == "__main__":
    main()

@wolfspider
Copy link
Author

The benefit of this approach is that it is easier to extend to sending an array of arbitrary bytes, where the number of slots may impose an upper bound on how much can be queued before the payload is sent. That functionality could easily be dropped in here. Now we can finally move on to part 2: remaking this at a lower level.

A couple of edits to the code also brought speed up to this:

python git:(master) ✗ python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user

Packets sent: [201,336,063], Duration: 5.84s
Average rate: [34.456] Mpps

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment