Created December 5, 2020 00:11
-
-
Save dondragmer/0c0b3eed0f7c30f7391deb11121a5aa1 to your computer and use it in GitHub Desktop.
A very fast GPU sort for values held within a single wavefront
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Values to sort. CuteSortTest and BitonicSortTest each sort every
// consecutive 32-element slice of this buffer independently.
Buffer<uint> Input;
// Destination for the sorted 32-element slices.
RWBuffer<uint> Output;
// Returns the index that this lane's `value` should be moved to so that the
// values held by the wave's active lanes end up sorted ascending. Equal values
// keep their relative lane order (lower lane index first).
//
// Assumptions (all masks are plain 32-bit uints):
//   * the wave has at most 32 lanes, so only component .x of the uint4
//     returned by WaveActiveBallot is used;
//   * laneIndex is this lane's index within the wave (0..31).
// Lanes that are inactive at the call site are excluded from the ranking, so
// the active lanes receive indices 0..(activeLaneCount - 1). For a fully
// active wave this is a permutation of 0..31.
//
// Cost: one WaveActiveBallot per examined value bit, plus one for the active mask.
uint CuteSort(uint value, uint laneIndex)
{
    // Lanes taking part in this call. Inactive lanes never vote true, so
    // without this mask they would look like lanes holding the value 0.
    uint activeLaneMask = WaveActiveBallot(true).x;

    // Bit i set => lane i holds a value strictly smaller than ours, judging
    // only by the bits examined so far.
    uint smallerValuesMask = 0;
    // Bit i set => lane i matches our value in every bit examined so far.
    uint equalValuesMask = activeLaneMask;

    // Walk from LSB to MSB: a differing higher bit overrides whatever the lower
    // bits decided, so after the last bit the masks reflect a full comparison.
    //don't need to test every bit if your value is constrained to a smaller range
    for (uint bit = 0; bit < 32; bit++)
    {
        bool isBitSet = (value & (1u << bit)) != 0;
        // .x: lanes 0..31 only (wave size <= 32 assumed).
        uint bitSetMask = WaveActiveBallot(isBitSet).x;
        if (isBitSet)
        {
            // Lanes with this bit clear are smaller regardless of lower bits.
            smallerValuesMask |= ~bitSetMask;
            equalValuesMask &= bitSetMask;
        }
        else
        {
            // Lanes with this bit set are larger regardless of lower bits.
            smallerValuesMask &= ~bitSetMask;
            equalValuesMask &= ~bitSetMask;
        }
    }

    // ~bitSetMask also set the bits of inactive lanes; drop them.
    smallerValuesMask &= activeLaneMask;

    //count up all the lanes with values that should be in front of this one
    uint numSmallerThanThis = countbits(smallerValuesMask);
    // Keep only equal lanes with a lower lane index. The shift is split in two
    // so laneIndex == 0 never needs a single shift by 32, which is not
    // guaranteed to produce 0.
    uint numEqualBeforeThis = countbits((equalValuesMask << (31 - laneIndex)) << 1);
    return numSmallerThanThis + numEqualBeforeThis;
}
// Wave-level permutation. Every lane offers `value` for delivery to lane
// `dstIndex`. Each lane returns the value whose dstIndex equals its own
// laneIndex.
//
// Assumptions: wave size <= 32 (ballot component .x only, 5 index bits),
// laneIndex is this lane's index within the wave, and dstIndex values form a
// permutation of the active lanes' indices (e.g. CuteSort output for a fully
// active wave). If no active lane targets this lane, equalIndexMask is 0 and
// the returned value is not meaningful.
uint ShuffleTo(uint value, uint dstIndex, uint laneIndex)
{
    // Inactive lanes vote false for every bit and would look like they target
    // lane 0; seeding with the active mask keeps them from ever matching.
    uint equalIndexMask = WaveActiveBallot(true).x;
    for (uint bit = 0; bit < 5; bit++)
    {
        uint bitSetMask = WaveActiveBallot((dstIndex & (1u << bit)) != 0).x;
        // Keep only lanes whose dstIndex agrees with our laneIndex in this bit.
        equalIndexMask &= (laneIndex & (1u << bit)) ? bitSetMask : ~bitSetMask;
    }
    uint laneWithOurValue = firstbitlow(equalIndexMask);
    return WaveReadLaneAt(value, laneWithOurValue);
}
// Test kernel. Ranks each value within its 32-lane group with CuteSort, then
// scatters it to that rank inside the group's 32-element slice of Output.
// Assumes lane i of a wave handles thread id.x with (id.x & 0x1F) == i and
// that a wave covers one aligned 32-element slice.
[numthreads(1024, 1, 1)]
void CuteSortTest(uint3 id : SV_DispatchThreadID)
{
    const uint lane = id.x & 0x1F;   // position within the 32-element slice
    const uint base = id.x & ~0x1F;  // first element of the slice
    const uint key = Input[id.x];
    const uint rank = CuteSort(key, lane);

    /*
    // Alternative: permute inside the wave first, then write in place:
    Output[id.x] = ShuffleTo(key, rank, lane);
    */

    Output[base + rank] = key;
}
// Bitonic sorting network across a wave. Returns the value that belongs at
// position `laneIndex` once the wave's values are sorted ascending. It moves
// the data itself, so no scatter is needed afterwards.
// Author's note: slower than CuteSort in most circumstances.
//
// Assumes a 32-lane wave with every lane active, and laneIndex equal to this
// lane's index in the wave (each step reads the partner lane via WaveReadLaneAt).
uint BitonicSort(uint value, uint laneIndex)
{
    // Each stage merges sorted runs of blockSize/2 into sorted runs of blockSize.
    for (uint blockSize = 2; blockSize <= 32; blockSize *= 2)
    {
        // Alternate run direction so neighbouring runs form bitonic sequences.
        // For blockSize == 32 this is false for every lane, so the final
        // order is ascending.
        const bool descendingRun = (laneIndex & blockSize) != 0;
        for (uint distance = blockSize / 2; distance != 0; distance /= 2)
        {
            const bool isHighPartner = (laneIndex & distance) != 0;
            // Decides which side of each compare-exchange pair keeps the minimum.
            const bool keepMin = (isHighPartner == descendingRun);
            const uint partner = WaveReadLaneAt(value, laneIndex ^ distance);
            value = ((partner < value) == keepMin) ? partner : value;
        }
    }
    return value;
}
// Test kernel. Sorts each 32-element slice of Input across the wave with
// BitonicSort and writes the sorted slice back to the same positions in Output.
[numthreads(1024, 1, 1)]
void BitonicSortTest(uint3 id : SV_DispatchThreadID)
{
    const uint lane = id.x & 0x1F;  // assumes lane index == thread index mod 32
    Output[id.x] = BitonicSort(Input[id.x], lane);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.