This gist is for an article I wrote on Medium (https://medium.com/p/4f4590cfd5d3).
//
//  ParallelRadixSort.metal
//
//  Created by Matthew Kieber-Emmons on 08/29/22.
//  Copyright © 2022 Matthew Kieber-Emmons. All rights reserved.
//  This work is for educational purposes only and cannot be used without consent.
//
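
//
//  Overview: an LSD (least-significant-digit) radix sort split across two
//  kernels. MakeHistogramOfPlaceValuesKernel counts how often each RADIX-sized
//  digit occurs per threadgroup; after those counts are prefix-summed into
//  global offsets, ReorderByPlaceValuesKernel ranks each value within its
//  threadgroup and scatters it to its destination. One histogram/scan/reorder
//  round runs per digit of the key.
//
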
#include <metal_stdlib>
using namespace metal;

////////////////////////////////////////////////////////////////
// MARK: - Compilation Constants
// these constants are typically provided at library generation, but we have sensible defaults here
////////////////////////////////////////////////////////////////

#ifndef THREADS_PER_THREADGROUP
#define THREADS_PER_THREADGROUP (256)
#endif

#ifndef VALUES_PER_THREAD
#define VALUES_PER_THREAD (4)
#endif

#ifndef EXECUTION_WIDTH
#define EXECUTION_WIDTH (32)
#endif

#ifndef LIBRARY_RADIX
#define LIBRARY_RADIX (32)
#endif

////////////////////////////////////////////////////////////////
// MARK: - Function Constants
// these constants control the code paths at pipeline creation
////////////////////////////////////////////////////////////////

constant int LOCAL_ALGORITHM [[function_constant(0)]];
constant int GLOBAL_ALGORITHM [[function_constant(1)]];
#define SORT_GLOBAL_ALGORITHM_LSD (0)

constant bool DISABLE_BOUNDS_CHECK [[function_constant(2)]];

constant bool ASCENDING [[function_constant(3)]];
#define SORT_DIRECTION_ASCENDING (0)
#define SORT_DIRECTION_DESCENDING (1)

constant int SORT_OPTIONS [[function_constant(4)]];

////////////////////////////////////////////////////////////////
// MARK: - Helpers
////////////////////////////////////////////////////////////////

static constexpr bool IsPowerOfTwo(uint32_t x){
    return ((x != 0) && !(x & (x - 1)));
}

static constexpr ushort RadixToBits(ushort n) {
    return (n-1<2)?1:
           (n-1<4)?2:
           (n-1<8)?3:
           (n-1<16)?4:
           (n-1<32)?5:
           (n-1<64)?6:
           (n-1<128)?7:
           (n-1<256)?8:
           (n-1<512)?9:
           (n-1<1024)?10:
           (n-1<2048)?11:
           (n-1<4096)?12:
           (n-1<8192)?13:
           (n-1<16384)?14:
           (n-1<32768)?15:0;
}
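
// Sanity checks: for a power-of-two radix, RadixToBits returns the number of
// key bits consumed per digit (i.e. log2 of the radix).
static_assert(RadixToBits(2) == 1, "");
static_assert(RadixToBits(32) == 5, "");
static_assert(RadixToBits(256) == 8, "");
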
template <ushort RADIX, typename T> static inline ushort
ValueToKeyAtBit(T value, ushort current_bit){
    return (value >> current_bit) & (RADIX - 1);
}

template <ushort RADIX> static inline ushort
ValueToKeyAtBit(int32_t value, ushort current_bit){
    return ((as_type<uint32_t>(value) ^ (1U << 31)) >> current_bit) & (RADIX - 1);
}
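
// Flipping the sign bit maps the int32 range monotonically onto uint32 order
// (INT_MIN -> 0x00000000, -1 -> 0x7FFFFFFF, 0 -> 0x80000000,
// INT_MAX -> 0xFFFFFFFF), so negative keys sort before positive ones.
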
template <ushort RADIX, typename T> static inline ushort
ValueToKeyAtDigit(T value, ushort current_digit){
    ushort bits_to_shift = RadixToBits(RADIX) * current_digit;
    return ValueToKeyAtBit<RADIX>(value, bits_to_shift);
}

///////////////////////////////////////////////////////////////////////////////
// MARK: - Load and Store Functions
///////////////////////////////////////////////////////////////////////////////

// blocked read into registers, i.e. ABCDEFGH -> AB, CD, EF, GH
template<ushort LENGTH, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[LENGTH],
                           const device T* input_data,
                           const ushort local_id) {
    for (ushort i = 0; i < LENGTH; i++){
        value[i] = input_data[local_id * LENGTH + i];
    }
}

// blocked read into registers with bounds checking
template<ushort LENGTH, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[LENGTH],
                           const device T* input_data,
                           const ushort local_id,
                           const uint n,
                           const T substitution_value) {
    for (ushort i = 0; i < LENGTH; i++){
        value[i] = (local_id * LENGTH + i < n) ? input_data[local_id * LENGTH + i] : substitution_value;
    }
}

// striped read into registers, i.e. ABCDEFGH -> AE, BF, CG, DH
template<ushort LENGTH, typename T> static void
LoadStripedLocalFromGlobal(thread T (&value)[LENGTH],
                           const device T* input_data,
                           const ushort local_id,
                           const ushort local_size) {
    for (ushort i = 0; i < LENGTH; i++){
        value[i] = input_data[local_id + i * local_size];
    }
}

// striped read into registers with bounds checking
template<ushort LENGTH, typename T> static void
LoadStripedLocalFromGlobal(thread T (&value)[LENGTH],
                           const device T* input_data,
                           const ushort local_id,
                           const ushort local_size,
                           const uint n,
                           const T substitution_value){
    // striped read with out-of-range elements replaced by substitution_value
    for (ushort i = 0; i < LENGTH; i++){
        value[i] = (local_id + i * local_size < n) ? input_data[local_id + i * local_size] : substitution_value;
    }
}
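
// Worked example with 4 threads and LENGTH = 2 over ABCDEFGH:
//   blocked: thread 0 -> {A,B}, thread 1 -> {C,D}, thread 2 -> {E,F}, thread 3 -> {G,H}
//   striped: thread 0 -> {A,E}, thread 1 -> {B,F}, thread 2 -> {C,G}, thread 3 -> {D,H}
// Striped reads keep consecutive threads on consecutive addresses, so each
// iteration of the loop is a coalesced global transaction.
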
///////////////////////////////////////////////////////////////////////////////
// MARK: - Prefix Scan Functions
///////////////////////////////////////////////////////////////////////////////

template <typename T>
struct SumOp {
    inline T operator()(thread const T& a, thread const T& b) const {return a + b;}
    inline T operator()(threadgroup const T& a, thread const T& b) const {return a + b;}
    inline T operator()(threadgroup const T& a, threadgroup const T& b) const {return a + b;}
    inline T operator()(volatile threadgroup const T& a, volatile threadgroup const T& b) const {return a + b;}
    constexpr T identity(){return static_cast<T>(0);}
};

template <typename T>
struct MaxOp {
    inline T operator()(thread const T& a, thread const T& b) const {return max(a,b);}
    inline T operator()(threadgroup const T& a, thread const T& b) const {return max(a,b);}
    inline T operator()(threadgroup const T& a, threadgroup const T& b) const {return max(a,b);}
    inline T operator()(volatile threadgroup const T& a, volatile threadgroup const T& b) const {return max(a,b);}
    constexpr T identity(){ return metal::numeric_limits<T>::min(); }
};

#define SCAN_TYPE_INCLUSIVE (0)
#define SCAN_TYPE_EXCLUSIVE (1)

template<ushort LENGTH, int SCAN_TYPE, typename BinaryOp, typename T>
static inline T ThreadScan(threadgroup T* values, BinaryOp Op){
    for (ushort i = 1; i < LENGTH; i++){
        values[i] = Op(values[i], values[i - 1]);
    }
    T result = values[LENGTH - 1];
    if (SCAN_TYPE == SCAN_TYPE_EXCLUSIVE){
        for (ushort i = LENGTH - 1; i > 0; i--){
            values[i] = values[i - 1];
        }
        values[0] = 0;
    }
    return result;
}
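
// Example with LENGTH = 4 and SumOp over {3,1,4,1}: an inclusive scan leaves
// {3,4,8,9} in place, an exclusive scan leaves {0,3,4,8}; both return the
// chunk total (9) so callers can stitch per-chunk scans together.
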
template<ushort LENGTH, typename BinaryOp, typename T> static inline void
ThreadUniformApply(thread T* values, T uni, BinaryOp Op){
    for (ushort i = 0; i < LENGTH; i++){
        values[i] = Op(values[i], uni);
    }
}

template<ushort LENGTH, typename BinaryOp, typename T> static inline void
ThreadUniformApply(threadgroup T* values, T uni, BinaryOp Op){
    for (ushort i = 0; i < LENGTH; i++){
        values[i] = Op(values[i], uni);
    }
}

template <int SCAN_TYPE, typename BinaryOp, typename T> static inline T
SimdgroupScan(T value, ushort local_id, BinaryOp Op){
    const ushort lane_id = local_id % 32;
    T temp = simd_shuffle_up(value, 1);
    if (lane_id >= 1) value = Op(value, temp);
    temp = simd_shuffle_up(value, 2);
    if (lane_id >= 2) value = Op(value, temp);
    temp = simd_shuffle_up(value, 4);
    if (lane_id >= 4) value = Op(value, temp);
    temp = simd_shuffle_up(value, 8);
    if (lane_id >= 8) value = Op(value, temp);
    temp = simd_shuffle_up(value, 16);
    if (lane_id >= 16) value = Op(value, temp);
    if (SCAN_TYPE == SCAN_TYPE_EXCLUSIVE){
        temp = simd_shuffle_up(value, 1);
        value = (lane_id == 0) ? 0 : temp;
    }
    return value;
}
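
// Shuffle-based Kogge-Stone (Hillis-Steele) scan: after the step with offset
// d, each lane has combined its value with the 2d - 1 preceding lanes, so the
// five doubling steps cover a 32-wide simdgroup with no threadgroup memory.
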
template<ushort BLOCK_SIZE, int SCAN_TYPE, typename BinaryOp, typename T> static T
ThreadgroupPrefixScanStoreSum(T value, thread T& inclusive_sum, threadgroup T* shared, const ushort local_id, BinaryOp Op) {
    shared[local_id] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (local_id < 32){
        T partial_sum = ThreadScan<BLOCK_SIZE / 32, SCAN_TYPE>(&shared[local_id * (BLOCK_SIZE / 32)], Op);
        T prefix = SimdgroupScan<SCAN_TYPE_EXCLUSIVE>(partial_sum, local_id, Op);
        ThreadUniformApply<BLOCK_SIZE / 32>(&shared[local_id * (BLOCK_SIZE / 32)], prefix, Op);
        if (local_id == 31) shared[0] = Op(prefix, partial_sum);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (SCAN_TYPE == SCAN_TYPE_INCLUSIVE) value = (local_id == 0) ? value : shared[local_id];
    else value = (local_id == 0) ? 0 : shared[local_id];
    inclusive_sum = shared[0];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return value;
}

template<ushort BLOCK_SIZE, int SCAN_TYPE, typename BinaryOp, typename T> static T
ThreadgroupPrefixScan(T value, threadgroup T* shared, const ushort local_id, BinaryOp Op) {
    // load values into shared memory
    shared[local_id] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // rake over shared mem
    if (local_id < 32){
        T partial_sum = ThreadScan<BLOCK_SIZE / 32, SCAN_TYPE>(&shared[local_id * (BLOCK_SIZE / 32)], Op);
        T prefix = SimdgroupScan<SCAN_TYPE_EXCLUSIVE>(partial_sum, local_id, Op);
        ThreadUniformApply<BLOCK_SIZE / 32>(&shared[local_id * (BLOCK_SIZE / 32)], prefix, Op);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    value = shared[local_id];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return value;
}
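
// Raking scan: the first simdgroup serially scans BLOCK_SIZE/32 contiguous
// elements per lane, one SimdgroupScan turns the 32 chunk totals into chunk
// prefixes, and ThreadUniformApply folds each prefix back over its chunk,
// using two barriers total instead of one per doubling step.
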
///////////////////////////////////////////////////////////////////////////////
// MARK: - Discontinuity Flagging Functions
///////////////////////////////////////////////////////////////////////////////

template <ushort BLOCK_SIZE, typename T> static uchar
FlagHeadDiscontinuity(const T value, threadgroup T* shared, const ushort local_id){
    shared[local_id] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    uchar result = (local_id == 0) ? 1 : shared[local_id] != shared[local_id - 1];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return result;
}

template <ushort BLOCK_SIZE, typename T> static uchar
FlagTailDiscontinuity(const T value, threadgroup T* shared, const ushort local_id){
    shared[local_id] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    uchar result = (local_id == BLOCK_SIZE - 1) ? 1 : shared[local_id] != shared[local_id + 1];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return result;
}
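
// Example: for threadgroup keys {5,5,7,7,7,9} the head flags are
// {1,0,1,0,0,1} and the tail flags are {0,1,0,0,1,1}; the reorder kernel uses
// them to locate the first and last element of each run of equal digits.
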
///////////////////////////////////////////////////////////////////////////////
// MARK: - Sorting Functions
///////////////////////////////////////////////////////////////////////////////

template <ushort BLOCK_SIZE, typename T> static T
SortByBit(const T value, threadgroup uint* shared, const ushort local_id, const uchar current_bit){
    // extract the value of the digit
    uchar mask = ValueToKeyAtBit<2>(value, current_bit);
    // 2-way scan
    uchar2 partial_sum;
    uchar2 scan = {0};
    scan[mask] = 1;
    scan = ThreadgroupPrefixScanStoreSum<BLOCK_SIZE, SCAN_TYPE_EXCLUSIVE>(scan,
                                                                          partial_sum,
                                                                          reinterpret_cast<threadgroup uchar2*>(shared),
                                                                          local_id,
                                                                          SumOp<uchar2>());
    // make offsets from the partial sums
    ushort2 offset;
    offset[0] = 0;
    offset[1] = offset[0] + partial_sum[0];
    shared[scan[mask] + offset[mask]] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // read new value from shared
    T result = shared[local_id];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return result;
}
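
// This is the classic two-way "split": each thread votes a 1 into the lane of
// its bit value, the exclusive scan gives its rank within that bucket, and the
// bucket totals give the bucket base offsets. The uchar2 counters are safe for
// BLOCK_SIZE <= 256: a bucket's total only wraps to 0 when it holds all 256
// elements, and then the other bucket's offset (the only consumer of that
// total) is never read.
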
template <ushort BLOCK_SIZE, typename T> static T
SortByTwoBits(const T value, threadgroup uint* shared, const ushort local_id, const uchar current_bit){
    uchar mask = ValueToKeyAtBit<4>(value, current_bit);
    // 4-way scan
    uchar4 partial_sum;
    uchar4 scan = {0};
    scan[mask] = 1;
    scan = ThreadgroupPrefixScanStoreSum<BLOCK_SIZE, SCAN_TYPE_EXCLUSIVE>(scan,
                                                                          partial_sum,
                                                                          reinterpret_cast<threadgroup uchar4*>(shared),
                                                                          local_id,
                                                                          SumOp<uchar4>());
    // make offsets from the partial sums
    ushort4 offset;
    offset[0] = 0;
    offset[1] = offset[0] + partial_sum[0];
    offset[2] = offset[1] + partial_sum[1];
    offset[3] = offset[2] + partial_sum[2];
    shared[scan[mask] + offset[mask]] = value;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // read new value from shared
    T result = shared[local_id];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return result;
}

template <ushort BLOCK_SIZE, typename T, ushort RADIX> static T
PartialRadixSort(const T value, threadgroup uint* shared, const ushort local_id, const ushort current_digit){
    T result = value;
    ushort current_bit = current_digit * RadixToBits(RADIX);
    const ushort last_bit = min(current_bit + RadixToBits(RADIX), (ushort)sizeof(T) * 8);
    while (current_bit < last_bit){
        if (last_bit - current_bit > 1){
            result = SortByTwoBits<BLOCK_SIZE>(result, shared, local_id, current_bit);
            current_bit += 2;
        } else {
            result = SortByBit<BLOCK_SIZE>(result, shared, local_id, current_bit);
            current_bit += 1;
        }
    }
    return result;
}
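
// Each 1- or 2-bit pass is a stable partition, so after the loop the
// threadgroup's BLOCK_SIZE values are ordered by the full RADIX-bit digit
// while ties keep their prior order, which is exactly what LSD radix sort
// requires across digits.
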
///////////////////////////////////////////////////////////////////////////////
// MARK: - Kernels
///////////////////////////////////////////////////////////////////////////////

template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, ushort RADIX, typename T> kernel void
MakeHistogramOfPlaceValuesKernel(device uint* output_data,
                                 device const T* input_data,
                                 constant uint& n,
                                 constant uint& current_digit,
                                 uint group_id [[threadgroup_position_in_grid]],
                                 uint grid_size [[threadgroups_per_grid]],
                                 ushort local_id [[thread_position_in_threadgroup]]) {
    static_assert((BLOCK_SIZE % 32) == 0, "ERROR - BLOCK_SIZE must be a multiple of the execution width");
    static_assert(IsPowerOfTwo(RADIX), "ERROR - RADIX must be a power of 2");
    uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
    // load data into registers
    T values[GRAIN_SIZE];
    if (DISABLE_BOUNDS_CHECK){
        LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id);
    } else {
        LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id, n - base_id, numeric_limits<T>::max());
    }
    // zero out the shared memory
    threadgroup uint histogram[RADIX];
    for (ushort i = 0; i < (RADIX + BLOCK_SIZE - 1) / BLOCK_SIZE; i++){
        if (local_id + i * BLOCK_SIZE < RADIX) histogram[local_id + i * BLOCK_SIZE] = 0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // iterate over values to update the histogram using an atomic add operation
    volatile threadgroup atomic_uint* atomic_histogram = reinterpret_cast<volatile threadgroup atomic_uint*>(histogram);
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        uchar key = ValueToKeyAtDigit<RADIX>(values[i], current_digit);
        if (DISABLE_BOUNDS_CHECK){
            atomic_fetch_add_explicit(&atomic_histogram[key], 1, memory_order_relaxed);
        } else {
            uint32_t predicate = (base_id + local_id * GRAIN_SIZE + i < n) ? 1 : 0; // for blocked reading
            atomic_fetch_add_explicit(&atomic_histogram[key], predicate, memory_order_relaxed);
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // store histogram to global in column major format (striped)
    for (ushort i = 0; i < (RADIX + BLOCK_SIZE - 1) / BLOCK_SIZE; i++){
        if (local_id + i * BLOCK_SIZE < RADIX){
            output_data[grid_size * (local_id + i * BLOCK_SIZE) + group_id] = histogram[local_id + i * BLOCK_SIZE];
        }
    }
}
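
// Layout example: with G threadgroups and RADIX buckets, output_data[r * G + g]
// holds the count of digit r in group g. A single exclusive prefix sum over
// this striped buffer then gives every (digit, group) pair its global scatter
// base, which ReorderByPlaceValuesKernel reads back as offsets_data.
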
template [[host_name("make_histogram_int32")]] kernel void | |
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,int>(device uint*, device const int*,constant uint&,constant uint&, uint,uint,ushort); | |
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,ushort>(device uint*, device const ushort*,constant uint&,constant uint&, uint,uint,ushort); | |
template [[host_name("make_histogram_uint32")]] kernel void | |
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,uint>(device uint*, device const uint*,constant uint&,constant uint&, uint,uint,ushort); | |

template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, ushort RADIX, typename T> kernel void
ReorderByPlaceValuesKernel(device T* output_data,
                           device const T* input_data,
                           constant uint& n,
                           device const uint* offsets_data,
                           constant uint& current_digit,
                           uint group_id [[threadgroup_position_in_grid]],
                           uint grid_size [[threadgroups_per_grid]],
                           ushort local_id [[thread_position_in_threadgroup]]) {
    uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
    // load data into registers
    T values[GRAIN_SIZE];
    if (DISABLE_BOUNDS_CHECK){
        LoadStripedLocalFromGlobal(values, &input_data[base_id], local_id, BLOCK_SIZE);
    } else {
        LoadStripedLocalFromGlobal(values, &input_data[base_id], local_id, BLOCK_SIZE, n - base_id, metal::numeric_limits<T>::max());
    }
    // sort striped values by threadgroup
    threadgroup uint shared_data[BLOCK_SIZE];
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        values[i] = PartialRadixSort<BLOCK_SIZE, T, RADIX>(values[i], shared_data, local_id, current_digit);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    threadgroup uint global_offset[RADIX];
    if (local_id < RADIX){
        global_offset[local_id] = offsets_data[grid_size * local_id + group_id];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // write to global using offsets in the histogram
    uint indexes[GRAIN_SIZE];
    for (ushort i = 0; i < GRAIN_SIZE; i++){
        // get local offset by scan of head flags of the range of digits
        uchar key = ValueToKeyAtDigit<RADIX>(values[i], current_digit);
        uchar flag = FlagHeadDiscontinuity<BLOCK_SIZE>(key, reinterpret_cast<threadgroup uchar*>(shared_data), local_id);
        ushort local_offset = local_id - ThreadgroupPrefixScan<BLOCK_SIZE, SCAN_TYPE_INCLUSIVE>(flag ? (ushort)local_id : (ushort)0,
                                                                                                reinterpret_cast<threadgroup ushort*>(shared_data),
                                                                                                local_id,
                                                                                                MaxOp<ushort>());
        indexes[i] = local_offset + global_offset[key];
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // update the global offsets - put flags into registers, then update indexes and offsets
        flag = FlagTailDiscontinuity<BLOCK_SIZE>(key, reinterpret_cast<threadgroup uchar*>(shared_data), local_id);
        if (flag){
            global_offset[key] += local_offset + 1;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // scatter to global
    if (DISABLE_BOUNDS_CHECK){
        for (ushort i = 0; i < GRAIN_SIZE; i++){
            output_data[indexes[i]] = values[i];
        }
    } else {
        for (ushort i = 0; i < GRAIN_SIZE; i++){
            if (indexes[i] < n) {
                output_data[indexes[i]] = values[i];
            }
        }
    }
}
template [[host_name("reorder_int32")]] kernel void | |
ReorderByPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,int> (device int*, device const int*, constant uint&, device const uint*, constant uint& t, uint, uint, ushort); | |
template [[host_name("reorder_uint32")]] kernel void | |
ReorderByPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,uint> (device uint*, device const uint*, constant uint&, device const uint*, constant uint& t, uint, uint, ushort); |