Skip to content

Instantly share code, notes, and snippets.

@xen0n
Created July 21, 2023 12:57

Revisions

  1. xen0n created this gist Jul 21, 2023.
    409 changes: 409 additions & 0 deletions linux-xor-simd-test.c
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,409 @@
    // SPDX-License-Identifier: GPL-2.0-or-later

    /*
    $ gcc -O3 -o linux-xor-simd-test linux-xor-simd-test.c
    $ ./linux-xor-simd-test
    ref (size=4096 ) passed 16383 times: 0.005044150 s total, 0.000000307 s per pass, 12687.191 MiB/s
    lsx_32b (size=4096 ) passed 16383 times: 0.002663250 s total, 0.000000162 s per pass, 24029.323 MiB/s
    lsx_64b (size=4096 ) passed 16383 times: 0.002517970 s total, 0.000000153 s per pass, 25415.749 MiB/s
    lsx_128b (size=4096 ) passed 16383 times: 0.002517590 s total, 0.000000153 s per pass, 25419.585 MiB/s
    lasx_32b (size=4096 ) passed 16383 times: 0.001935550 s total, 0.000000118 s per pass, 33063.519 MiB/s
    lasx_64b (size=4096 ) passed 16383 times: 0.001813990 s total, 0.000000110 s per pass, 35279.188 MiB/s
    lasx_128b (size=4096 ) passed 16383 times: 0.001756910 s total, 0.000000107 s per pass, 36425.368 MiB/s
    */

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <time.h>
    #include <sys/random.h>

    // may not be widely available yet
    // #include <lsxintrin.h>

    // #define DATA_SIZE_MIN_ORDER 9 // 512
    // #define DATA_SIZE_MAX_ORDER 20 // 1MiB

    // Benchmarked buffer sizes are (1 << order) bytes, for every order from
    // DATA_SIZE_MIN_ORDER to DATA_SIZE_MAX_ORDER inclusive.
    // same as crypto/xor.c
    #define DATA_SIZE_MIN_ORDER 12 // 4KiB
    #define DATA_SIZE_MAX_ORDER 12 // 4KiB

    // Number of passes each implementation makes over the same buffer pair.
    // XOR-ing the same source into the destination twice cancels out, so only
    // an odd pass count leaves the destination equal to the single-pass
    // reference result it is compared against.
    #define TIMES 16383 // must be odd

    // Common signature of every implementation under test: (dst, src, len).
    // The first buffer is XOR-ed in place with the second; len is in bytes.
    typedef void (*xor_impl_t)(void * __restrict, const void * __restrict, size_t);

    // taken from linux include/asm-generic/xor.h
    /*
     * Portable reference kernel: p1[i] ^= p2[i], eight machine words per
     * iteration (unrolled by hand, as in the kernel's generic template).
     *
     * bytes: number of bytes to process. Only whole groups of eight
     *        unsigned longs are handled; a trailing partial group is ignored
     *        and nothing is touched when bytes is smaller than one group.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_8regs_2(unsigned long bytes, unsigned long * __restrict p1,
            const unsigned long * __restrict p2)
    {
        unsigned long lines = bytes / sizeof(unsigned long) / 8;

        /*
         * Test before the first pass: a do/while here would run once even for
         * lines == 0 and XOR eight words beyond a too-small buffer.
         */
        while (lines-- > 0) {
            p1[0] ^= p2[0];
            p1[1] ^= p2[1];
            p1[2] ^= p2[2];
            p1[3] ^= p2[3];
            p1[4] ^= p2[4];
            p1[5] ^= p2[5];
            p1[6] ^= p2[6];
            p1[7] ^= p2[7];
            p1 += 8;
            p2 += 8;
        }
    }

    /*
     * xor_impl_t adapter for the portable reference kernel: views both
     * buffers as arrays of unsigned long and XORs b into a.
     */
    static void reference_xor(void * __restrict a, const void * __restrict b, size_t len)
    {
        unsigned long *dst_words = a;
        const unsigned long *src_words = b;

        xor_8regs_2(len, dst_words, src_words);
    }

    //
    // LSX
    //

    /*
     * XOR p2 into p1 using LSX, 32 bytes (two 128-bit vectors) per iteration.
     *
     * bytes: number of bytes to process; only whole 32-byte chunks are
     *        handled, and nothing is touched when bytes < 32.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lsx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 32; lines > 0; lines--) {
            /*
             * $vr0-$vr3 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "vld $vr0, %[dst], 0\n\t"
                "vld $vr1, %[dst], 16\n\t"

                "vld $vr2, %[src], 0\n\t"
                "vld $vr3, %[src], 16\n\t"

                "vxor.v $vr0, $vr0, $vr2\n\t"
                "vxor.v $vr1, $vr1, $vr3\n\t"

                "vst $vr0, %[dst], 0\n\t"
                "vst $vr1, %[dst], 16\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1", "$f2", "$f3"
            );
            dst += 32;
            src += 32;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lsx_32b()'s (bytes, dst, src). */
    static void lsx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lsx_32b(nbytes, a, b);
    }

    /*
     * XOR p2 into p1 using LSX, 64 bytes (four 128-bit vectors) per iteration.
     *
     * bytes: number of bytes to process; only whole 64-byte chunks are
     *        handled, and nothing is touched when bytes < 64.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lsx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 64; lines > 0; lines--) {
            /*
             * $vr0-$vr7 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "vld $vr0, %[dst], 0\n\t"
                "vld $vr1, %[dst], 16\n\t"
                "vld $vr2, %[dst], 32\n\t"
                "vld $vr3, %[dst], 48\n\t"

                "vld $vr4, %[src], 0\n\t"
                "vld $vr5, %[src], 16\n\t"
                "vld $vr6, %[src], 32\n\t"
                "vld $vr7, %[src], 48\n\t"

                "vxor.v $vr0, $vr0, $vr4\n\t"
                "vxor.v $vr1, $vr1, $vr5\n\t"
                "vxor.v $vr2, $vr2, $vr6\n\t"
                "vxor.v $vr3, $vr3, $vr7\n\t"

                "vst $vr0, %[dst], 0\n\t"
                "vst $vr1, %[dst], 16\n\t"
                "vst $vr2, %[dst], 32\n\t"
                "vst $vr3, %[dst], 48\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1", "$f2", "$f3",
                  "$f4", "$f5", "$f6", "$f7"
            );
            dst += 64;
            src += 64;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lsx_64b()'s (bytes, dst, src). */
    static void lsx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lsx_64b(nbytes, a, b);
    }

    /*
     * XOR p2 into p1 using LSX, 128 bytes (eight 128-bit vectors) per iteration.
     *
     * bytes: number of bytes to process; only whole 128-byte chunks are
     *        handled, and nothing is touched when bytes < 128.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lsx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 128; lines > 0; lines--) {
            /*
             * $vr0-$vr15 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "vld $vr0, %[dst], 0\n\t"
                "vld $vr1, %[dst], 16\n\t"
                "vld $vr2, %[dst], 32\n\t"
                "vld $vr3, %[dst], 48\n\t"
                "vld $vr4, %[dst], 64\n\t"
                "vld $vr5, %[dst], 80\n\t"
                "vld $vr6, %[dst], 96\n\t"
                "vld $vr7, %[dst], 112\n\t"

                "vld $vr8, %[src], 0\n\t"
                "vld $vr9, %[src], 16\n\t"
                "vld $vr10, %[src], 32\n\t"
                "vld $vr11, %[src], 48\n\t"
                "vld $vr12, %[src], 64\n\t"
                "vld $vr13, %[src], 80\n\t"
                "vld $vr14, %[src], 96\n\t"
                "vld $vr15, %[src], 112\n\t"

                "vxor.v $vr0, $vr0, $vr8\n\t"
                "vxor.v $vr1, $vr1, $vr9\n\t"
                "vxor.v $vr2, $vr2, $vr10\n\t"
                "vxor.v $vr3, $vr3, $vr11\n\t"
                "vxor.v $vr4, $vr4, $vr12\n\t"
                "vxor.v $vr5, $vr5, $vr13\n\t"
                "vxor.v $vr6, $vr6, $vr14\n\t"
                "vxor.v $vr7, $vr7, $vr15\n\t"

                "vst $vr0, %[dst], 0\n\t"
                "vst $vr1, %[dst], 16\n\t"
                "vst $vr2, %[dst], 32\n\t"
                "vst $vr3, %[dst], 48\n\t"
                "vst $vr4, %[dst], 64\n\t"
                "vst $vr5, %[dst], 80\n\t"
                "vst $vr6, %[dst], 96\n\t"
                "vst $vr7, %[dst], 112\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1", "$f2", "$f3",
                  "$f4", "$f5", "$f6", "$f7",
                  "$f8", "$f9", "$f10", "$f11",
                  "$f12", "$f13", "$f14", "$f15"
            );
            dst += 128;
            src += 128;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lsx_128b()'s (bytes, dst, src). */
    static void lsx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lsx_128b(nbytes, a, b);
    }

    //
    // LASX
    //

    /*
     * XOR p2 into p1 using LASX, 32 bytes (one 256-bit vector) per iteration.
     *
     * bytes: number of bytes to process; only whole 32-byte chunks are
     *        handled, and nothing is touched when bytes < 32.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lasx_32b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 32; lines > 0; lines--) {
            /*
             * $xr0-$xr1 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "xvld $xr0, %[dst], 0\n\t"

                "xvld $xr1, %[src], 0\n\t"

                "xvxor.v $xr0, $xr0, $xr1\n\t"

                "xvst $xr0, %[dst], 0\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1"
            );
            dst += 32;
            src += 32;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lasx_32b()'s (bytes, dst, src). */
    static void lasx_32b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lasx_32b(nbytes, a, b);
    }

    /*
     * XOR p2 into p1 using LASX, 64 bytes (two 256-bit vectors) per iteration.
     *
     * bytes: number of bytes to process; only whole 64-byte chunks are
     *        handled, and nothing is touched when bytes < 64.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lasx_64b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 64; lines > 0; lines--) {
            /*
             * $xr0-$xr3 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "xvld $xr0, %[dst], 0\n\t"
                "xvld $xr1, %[dst], 32\n\t"

                "xvld $xr2, %[src], 0\n\t"
                "xvld $xr3, %[src], 32\n\t"

                "xvxor.v $xr0, $xr0, $xr2\n\t"
                "xvxor.v $xr1, $xr1, $xr3\n\t"

                "xvst $xr0, %[dst], 0\n\t"
                "xvst $xr1, %[dst], 32\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1", "$f2", "$f3"
            );
            dst += 64;
            src += 64;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lasx_64b()'s (bytes, dst, src). */
    static void lasx_64b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lasx_64b(nbytes, a, b);
    }

    /*
     * XOR p2 into p1 using LASX, 128 bytes (four 256-bit vectors) per iteration.
     *
     * bytes: number of bytes to process; only whole 128-byte chunks are
     *        handled, and nothing is touched when bytes < 128.
     * p1:    destination, updated in place.
     * p2:    source, must not overlap p1.
     */
    static void
    xor_lasx_128b(unsigned long bytes, void * __restrict p1, const void * __restrict p2)
    {
        /* Byte pointers: arithmetic on void * is only a GNU extension. */
        unsigned char *dst = p1;
        const unsigned char *src = p2;
        unsigned long lines;

        for (lines = bytes / 128; lines > 0; lines--) {
            /*
             * $xr0-$xr7 are overwritten below. The vector registers overlay
             * the floating-point registers, so the matching $fN are listed as
             * clobbers to keep the compiler from holding live values there.
             */
            asm volatile (
                "xvld $xr0, %[dst], 0\n\t"
                "xvld $xr1, %[dst], 32\n\t"
                "xvld $xr2, %[dst], 64\n\t"
                "xvld $xr3, %[dst], 96\n\t"

                "xvld $xr4, %[src], 0\n\t"
                "xvld $xr5, %[src], 32\n\t"
                "xvld $xr6, %[src], 64\n\t"
                "xvld $xr7, %[src], 96\n\t"

                "xvxor.v $xr0, $xr0, $xr4\n\t"
                "xvxor.v $xr1, $xr1, $xr5\n\t"
                "xvxor.v $xr2, $xr2, $xr6\n\t"
                "xvxor.v $xr3, $xr3, $xr7\n\t"

                "xvst $xr0, %[dst], 0\n\t"
                "xvst $xr1, %[dst], 32\n\t"
                "xvst $xr2, %[dst], 64\n\t"
                "xvst $xr3, %[dst], 96\n\t"
                : : [dst] "r"(dst), [src] "r"(src)
                : "memory", "$f0", "$f1", "$f2", "$f3",
                  "$f4", "$f5", "$f6", "$f7"
            );
            dst += 128;
            src += 128;
        }
    }

    /* xor_impl_t adapter: reorders (dst, src, len) into xor_lasx_128b()'s (bytes, dst, src). */
    static void lasx_128b_glue(void * __restrict a, const void * __restrict b, size_t len)
    {
        const unsigned long nbytes = len;

        xor_lasx_128b(nbytes, a, b);
    }

    //
    // helpers
    //

    /*
     * Fill buf[0 .. len) with bytes from getrandom(), calling it repeatedly
     * until every byte has been written. Aborts the process if getrandom()
     * reports an error. A zero len returns without touching buf.
     */
    static void must_fill_randomness(void *buf, size_t len)
    {
        unsigned char *cursor = buf;
        size_t remaining = len;

        while (remaining != 0) {
            const ssize_t got = getrandom(cursor, remaining, 0);

            if (got < 0)
                abort();

            /* A short read is fine: continue right after the bytes received. */
            cursor += got;
            remaining -= got;
        }
    }

    /*
     * Return *time1 - *time0 as a timespec. When the nanosecond difference is
     * negative, one second is borrowed so tv_nsec ends up non-negative.
     */
    static struct timespec diff_timespec(
            const struct timespec *time1,
            const struct timespec *time0)
    {
        struct timespec delta = {0};

        delta.tv_sec = time1->tv_sec - time0->tv_sec;
        delta.tv_nsec = time1->tv_nsec - time0->tv_nsec;
        if (delta.tv_nsec >= 0)
            return delta;

        /* Borrow one second's worth of nanoseconds. */
        delta.tv_sec -= 1;
        delta.tv_nsec += 1000000000;
        return delta;
    }

    /*
     * Divide a time span by denom (integer division on whole nanoseconds).
     * The span is first flattened into a single long nanosecond count, so it
     * must be small enough for that count not to overflow.
     */
    static struct timespec div_timespec(struct timespec x, int denom)
    {
        const long nsec_per_sec = 1000000000;
        struct timespec share = {0};
        long total_ns = x.tv_sec * nsec_per_sec + x.tv_nsec;
        long share_ns = total_ns / denom;

        share.tv_sec = share_ns / nsec_per_sec;
        share.tv_nsec = share_ns % nsec_per_sec;
        return share;
    }

    /*
     * Throughput in bytes per second for `times` passes over `size` bytes
     * that took `elapsed` in total.
     */
    static double get_throughput(int size, struct timespec elapsed, int times)
    {
        const double bytes_moved = (double)((long)size * (long)times);
        const double seconds =
            (double)(elapsed.tv_sec * 1000000000l + (long)(elapsed.tv_nsec)) / 1e9;

        return bytes_moved / seconds;
    }

    /*
     * Benchmark and verify one implementation on buffers of (1 << order) bytes.
     *
     * Two random buffers are generated and the expected result is computed
     * once with reference_xor(). fn is then applied TIMES times to the same
     * pair while thread CPU time is measured, and the final destination
     * contents are compared with the expected result. One report line is
     * printed to stdout.
     *
     * Returns 0 when the result matches the reference, 1 otherwise. Aborts on
     * allocation or clock failure.
     */
    static int run_order(int order, const char *desc, xor_impl_t fn)
    {
        const int size = 1 << order;
        struct timespec t_begin, t_end, total, per_pass;
        void *a, *b, *ref;
        int pass, mismatch;

        a = malloc(size);
        if (a == NULL)
            abort();
        b = malloc(size);
        if (b == NULL)
            abort();
        ref = malloc(size);
        if (ref == NULL)
            abort();

        must_fill_randomness(a, size);
        must_fill_randomness(b, size);

        /* Expected outcome: a single XOR of b into a copy of a. */
        memcpy(ref, a, size);
        reference_xor(ref, b, size);

        if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t_begin) != 0)
            abort();

        pass = 0;
        while (pass < TIMES) {
            fn(a, b, size);
            pass++;
        }

        if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &t_end) != 0)
            abort();

        total = diff_timespec(&t_end, &t_begin);
        per_pass = div_timespec(total, TIMES);

        mismatch = memcmp(a, ref, size) != 0;
        printf(
            "%-10s(size=%-7d) %s %d times: %ld.%09ld s total, %ld.%09ld s per pass, %.3lf MiB/s\n",
            desc,
            size,
            mismatch ? "failed" : "passed",
            TIMES,
            total.tv_sec,
            total.tv_nsec,
            per_pass.tv_sec,
            per_pass.tv_nsec,
            get_throughput(size, total, TIMES) / 1048576.0
        );

        free(ref);
        free(b);
        free(a);
        return mismatch;
    }

    /*
     * Run one implementation at every size order from DATA_SIZE_MIN_ORDER to
     * DATA_SIZE_MAX_ORDER inclusive. Returns the OR of the per-size results,
     * i.e. non-zero if any size failed verification.
     */
    static int try_all_orders(const char *desc, xor_impl_t fn)
    {
        int failed = 0;
        int order = DATA_SIZE_MIN_ORDER;

        while (order <= DATA_SIZE_MAX_ORDER) {
            failed |= run_order(order, desc, fn);
            order++;
        }
        return failed;
    }

    /*
     * Run the reference implementation and then each LSX/LASX variant, in
     * that order. The exit status is non-zero if any of them failed
     * verification.
     */
    int main(int argc, const char *argv[])
    {
        /* One statement per implementation keeps the call order fixed. */
        const int ref_failed = try_all_orders("ref", reference_xor);
        const int lsx32_failed = try_all_orders("lsx_32b", lsx_32b_glue);
        const int lsx64_failed = try_all_orders("lsx_64b", lsx_64b_glue);
        const int lsx128_failed = try_all_orders("lsx_128b", lsx_128b_glue);
        const int lasx32_failed = try_all_orders("lasx_32b", lasx_32b_glue);
        const int lasx64_failed = try_all_orders("lasx_64b", lasx_64b_glue);
        const int lasx128_failed = try_all_orders("lasx_128b", lasx_128b_glue);

        return ref_failed
            | lsx32_failed | lsx64_failed | lsx128_failed
            | lasx32_failed | lasx64_failed | lasx128_failed;
    }