I figured out what I think are a few unusual optimizations, and while I'm not really sure that any of them are new, the combination makes my code run pretty fast.
In particular, a histogram doesn't depend on the order of the data, so I just do all the histogramming in one pass through the data. One read builds several histograms.
So if you sort a floating point number in four 8-bit passes, you can build all four histograms in one pass through the data, rather than four.
This would mean 5 passes through the data (one to histogram, four to sort), rather than the usual 8.
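To make the one-pass idea concrete, here is a minimal sketch of building four 8-bit histograms with a single read of each key. The names Histograms and BuildByteHistograms are made up for this illustration and are not part of my radix code; the sketch also assumes the keys are already plain 32-bit unsigned integers.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Four 256-entry histograms, one per byte of a 32-bit key.
using Histograms = std::array<std::array<std::uint32_t, 256>, 4>;

// Hypothetical helper (illustration only): counts all four byte digits of
// every key while reading the data exactly once.
inline Histograms BuildByteHistograms(const std::uint32_t* keys, std::size_t count) {
    assert(keys != nullptr || count == 0);
    Histograms h{};  // value-initialized, so every count starts at zero
    for (std::size_t i = 0; i < count; ++i) {
        const std::uint32_t k = keys[i];  // one read...
        ++h[0][k & 0xFF];                 // ...feeds four histograms
        ++h[1][(k >> 8) & 0xFF];
        ++h[2][(k >> 16) & 0xFF];
        ++h[3][k >> 24];
    }
    return h;
}
```

Each of the four sorting passes would then use its own histogram (h[0] for the lowest byte, and so on) without touching the data again to count.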
If you look at them in binary, single precision floating point numbers have two problems that keep them from being directly sortable.
[sign] [exponent] [mantissa]
First, the sign bit is set when the value is negative, which means that all negative numbers look bigger than positive ones when you compare the bits as unsigned integers. You could just flip it, of course (I thought this was all I had to do at first), but there's still another problem.
The second problem is that the values are signed-magnitude, so "more negative" floating point numbers actually look bigger to a normal bitwise comparison.
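Both problems are easy to see if you compare the raw bits directly. The following is a small sketch, assuming 32-bit IEEE floats; FloatBits and DemoRawBitOrder are names I made up for the illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: returns the raw bit pattern of a float.
inline std::uint32_t FloatBits(float v) {
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit floats");
    std::uint32_t bits;
    std::memcpy(&bits, &v, sizeof bits);  // copy the bits without aliasing tricks
    return bits;
}

// Shows where an unsigned compare of the raw bits agrees or disagrees
// with the real floating point order.
inline void DemoRawBitOrder() {
    assert(FloatBits(1.0f) < FloatBits(2.0f));    // positives: order is fine
    assert(FloatBits(-1.0f) > FloatBits(1.0f));   // problem 1: negative looks bigger
    assert(FloatBits(-2.0f) > FloatBits(-1.0f));  // problem 2: more negative looks bigger
}
```

So positive values already sort correctly as integers; only the sign handling and the negative range need fixing.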
To fix this, we have to do some bit-twiddling on the integer representation. It turns out that flipping the exponent bits inverts the order of the exponents (flips them low to high), and flipping the mantissa bits does the same. Basically, the exponent is a "range" of numbers, and we flip the orders of these ranges. And then the mantissa is the numbers in each range, and we flip these as well.
We're supposed to call this a "bijective mapping" from 32-bit integers to themselves, which means, for every 32-bit number, there's another unique one that we map to, and we can invert the mapping to get back exactly where we started.
It turns out that when you say "flip" you can also say, "xor with 1's" -- for instance, if you had an 8-bit number "x", to compute "255-x" you could simply use (255 ^ x) instead -- without any of the evils of carrying like addition has. (In case it wasn't clear, 255 is eight "1s" in a row.)
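If you don't believe the xor claim, it is cheap to check every 8-bit value. This is just a sketch; FlipBits8 and CheckFlipMatchesSubtraction are hypothetical names for the illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: "flips" an 8-bit value by xor'ing with eight 1 bits.
inline std::uint8_t FlipBits8(std::uint8_t x) {
    return static_cast<std::uint8_t>(x ^ 0xFF);
}

// Verifies that the xor gives the same answer as 255 - x for all 256 inputs.
inline void CheckFlipMatchesSubtraction() {
    for (int x = 0; x < 256; ++x) {
        assert(FlipBits8(static_cast<std::uint8_t>(x)) == 255 - x);
    }
}
```

The same reasoning applies at 32 bits: xor with 0xFFFFFFFF reverses the order of all 32-bit values.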
So, to fix our floating point numbers, we define the following rules: always flip the sign bit, and if the sign bit was set, flip all the other bits too.
If we write this as a single xor, what we want is:
When the sign bit is set, xor with 0xFFFFFFFF (flip every bit)
When the sign bit is unset, xor with 0x80000000 (flip the sign bit)
This leads to two routines to convert floating point values to sortable numbers and back again. I call them FloatFlip and IFloatFlip:
    static inline uint32 FloatFlip(uint32 f)
    {
        uint32 mask = -int32(f >> 31) | 0x80000000;
        return f ^ mask;
    }

    static inline uint32 IFloatFlip(uint32 f)
    {
        uint32 mask = ((f >> 31) - 1) | 0x80000000;
        return f ^ mask;
    }

This shifts the sign bit down 31 places (to make the entire number "0" or "1"), and then either negates it (FloatFlip) or subtracts one (IFloatFlip). In particular, this always makes "0" or "0xFFFFFFFF", which becomes "0x80000000" or "0xFFFFFFFF" after or'ing in the top bit (0x80000000).
Works nicely.
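Here is a small self-check of both properties: the flipped keys sort in the same order as the floats, and the inverse gets you back exactly. It is a sketch rather than the code I ship: it uses fixed-width std::uint32_t instead of my typedefs, writes the negation as 0u - x (which gives the same mask), and FloatToKey, KeyToFloat and CheckFlipOrdering are hypothetical helpers added for the test.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Same masks as FloatFlip/IFloatFlip, written with fixed-width types.
inline std::uint32_t FloatFlip(std::uint32_t f) {
    std::uint32_t mask = (0u - (f >> 31)) | 0x80000000u;  // 0xFFFFFFFF if negative
    return f ^ mask;
}

inline std::uint32_t IFloatFlip(std::uint32_t f) {
    std::uint32_t mask = ((f >> 31) - 1u) | 0x80000000u;  // 0xFFFFFFFF if top bit clear
    return f ^ mask;
}

// Hypothetical helpers: float -> sortable key and back.
inline std::uint32_t FloatToKey(float v) {
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit floats");
    std::uint32_t bits;
    std::memcpy(&bits, &v, sizeof bits);
    return FloatFlip(bits);
}

inline float KeyToFloat(std::uint32_t key) {
    const std::uint32_t bits = IFloatFlip(key);
    float v;
    std::memcpy(&v, &bits, sizeof v);
    return v;
}

// Keys must increase along an increasing list of floats, and round-trip exactly.
inline void CheckFlipOrdering() {
    const float values[] = {-2.0f, -1.0f, -0.0f, 0.0f, 1.0f, 2.0f};
    const int n = sizeof(values) / sizeof(values[0]);
    for (int i = 0; i + 1 < n; ++i) {
        assert(FloatToKey(values[i]) < FloatToKey(values[i + 1]));
    }
    for (int i = 0; i < n; ++i) {
        assert(KeyToFloat(FloatToKey(values[i])) == values[i]);
    }
}
```

Note that -0.0 ends up just below +0.0 in key order, even though the two compare equal as floats.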
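An initial implementation used 8-bit histograms. The final code uses three 11-bit histograms (2048 entries each) instead, so a 32-bit value is sorted in three passes rather than four, and this 11-bit optimization improves performance by about 40%.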
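The three 11-bit digits cover a 32-bit key as 11 + 11 + 10 bits. A minimal sketch of pulling them out follows; Digit11 is a hypothetical function name for the illustration, doing the same job as the _0, _1 and _2 macros in the listing.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns the digit used by radix pass 0, 1 or 2.
// Pass 2 only has 10 bits left (32 - 22), so its mask is a harmless no-op.
inline std::uint32_t Digit11(std::uint32_t key, int pass) {
    assert(pass >= 0 && pass <= 2);
    return (key >> (11 * pass)) & 0x7FF;
}
```

Each digit indexes a 2048-entry histogram, so three such histograms take 3 * 2048 * 4 = 24576 bytes.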
My test case was:
After all the above optimizations, the radix sort achieves 97 sorts/sec on this test, a quite amazing improvement.
I hope the code is quite readable, so I'm just posting it here for everybody to try. You'd have to do some modifications to make it useful, but this is nice as a raw speed test.
    // Radix.cpp: a fast floating-point radix sort demo
    //
    // Copyright (C) Herf Consulting LLC 2001. All Rights Reserved.
    // Use for anything you want, just tell me what you do with it.
    // Code provided "as-is" with no liabilities for anything that goes wrong.
    //

    #include <stdio.h>
    #include <string.h>
    #include <stdlib.h>
    #include <windows.h>    // QueryPerformanceCounter

    // ------------------------------------------------------------------------------------------------
    // ---- Basic types

    typedef long int32;
    typedef unsigned long uint32;
    typedef float real32;
    typedef double real64;
    typedef unsigned char uint8;
    typedef const char *cpointer;

    // ------------------------------------------------------------------------------------------------
    // Configuration/Testing

    // ---- number of elements to test (shows tradeoff of histogram size vs. sort size)
    const uint32 ct = 65536;

    // ---- really, a correctness check, not correctness itself ;)
    #define CORRECTNESS 1

    // ---- use SSE prefetch (needs compiler support), not really a problem on non-SSE machines.
    //      need http://msdn.microsoft.com/vstudio/downloads/ppack/default.asp
    //      or recent VC to use this

    #define PREFETCH 1

    #if PREFETCH
    #include <xmmintrin.h>  // for prefetch
    #define pfval   64
    #define pfval2  128
    #define pf(x)   _mm_prefetch(cpointer(x + i + pfval), 0)
    #define pf2(x)  _mm_prefetch(cpointer(x + i + pfval2), 0)
    #else
    #define pf(x)
    #define pf2(x)
    #endif

    // ------------------------------------------------------------------------------------------------
    // ---- Visual C++ eccentricities

    #if _WINDOWS
    #define finline __forceinline
    #else
    #define finline inline
    #endif

    // ================================================================================================
    // flip a float for sorting
    //  finds SIGN of fp number.
    //  if it's 1 (negative float), it flips all bits
    //  if it's 0 (positive float), it flips the sign only
    // ================================================================================================
    finline uint32 FloatFlip(uint32 f)
    {
        uint32 mask = -int32(f >> 31) | 0x80000000;
        return f ^ mask;
    }

    finline void FloatFlipX(uint32 &f)
    {
        uint32 mask = -int32(f >> 31) | 0x80000000;
        f ^= mask;
    }

    // ================================================================================================
    // flip a float back (invert FloatFlip)
    //  signed was flipped from above, so:
    //  if sign is 1 (negative), it flips the sign bit back
    //  if sign is 0 (positive), it flips all bits back
    // ================================================================================================
    finline uint32 IFloatFlip(uint32 f)
    {
        uint32 mask = ((f >> 31) - 1) | 0x80000000;
        return f ^ mask;
    }

    // ---- utils for accessing 11-bit quantities
    #define _0(x)   (x & 0x7FF)
    #define _1(x)   (x >> 11 & 0x7FF)
    #define _2(x)   (x >> 22 )

    // ================================================================================================
    // Main radix sort
    // ================================================================================================
    static void RadixSort11(real32 *farray, real32 *sorted, uint32 elements)
    {
        uint32 i;
        uint32 *sort = (uint32*)sorted;
        uint32 *array = (uint32*)farray;

        // 3 histograms on the stack:
        const uint32 kHist = 2048;
        uint32 b0[kHist * 3];

        uint32 *b1 = b0 + kHist;
        uint32 *b2 = b1 + kHist;

        for (i = 0; i < kHist * 3; i++) {
            b0[i] = 0;
        }
        //memset(b0, 0, kHist * 12);

        // 1. parallel histogramming pass
        //
        for (i = 0; i < elements; i++) {

            pf(array);

            uint32 fi = FloatFlip((uint32&)array[i]);

            b0[_0(fi)] ++;
            b1[_1(fi)] ++;
            b2[_2(fi)] ++;
        }

        // 2. Sum the histograms -- each histogram entry records the number of values preceding itself.
        {
            uint32 sum0 = 0, sum1 = 0, sum2 = 0;
            uint32 tsum;
            for (i = 0; i < kHist; i++) {

                tsum = b0[i] + sum0;
                b0[i] = sum0 - 1;
                sum0 = tsum;

                tsum = b1[i] + sum1;
                b1[i] = sum1 - 1;
                sum1 = tsum;

                tsum = b2[i] + sum2;
                b2[i] = sum2 - 1;
                sum2 = tsum;
            }
        }

        // byte 0: floatflip entire value, read/write histogram, write out flipped
        for (i = 0; i < elements; i++) {

            uint32 fi = array[i];
            FloatFlipX(fi);
            uint32 pos = _0(fi);

            pf2(array);
            sort[++b0[pos]] = fi;
        }

        // byte 1: read/write histogram, copy
        //   sorted -> array
        for (i = 0; i < elements; i++) {
            uint32 si = sort[i];
            uint32 pos = _1(si);
            pf2(sort);
            array[++b1[pos]] = si;
        }

        // byte 2: read/write histogram, copy & flip out
        //   array -> sorted
        for (i = 0; i < elements; i++) {
            uint32 ai = array[i];
            uint32 pos = _2(ai);

            pf2(array);
            sort[++b2[pos]] = IFloatFlip(ai);
        }

        // to write original:
        // memcpy(array, sorted, elements * 4);
    }

    // Simple test of radix
    int main(int argc, char* argv[])
    {
        uint32 i;

        const uint32 trials = 100;

        real32 *a = new real32[ct];
        real32 *b = new real32[ct];

        for (i = 0; i < ct; i++) {
            a[i] = real32(rand()) / 2048;
            if (rand() & 1) {
                a[i] = -a[i];
            }
        }

        LARGE_INTEGER s1, s2, f;
        QueryPerformanceFrequency(&f);
        QueryPerformanceCounter(&s1);

        for (i = 0; i < trials; i++) {
            RadixSort11(a, b, ct);
        }

        QueryPerformanceCounter(&s2);
        real64 hz = real64(f.QuadPart) / (s2.QuadPart - s1.QuadPart);

        // uint32 is unsigned long here, so print it with %lu
        printf("%lu elements: %f/s\n", ct, hz * trials);

    #if CORRECTNESS
        for (i = 1; i < ct; i++) {
            if (b[i - 1] > b[i]) {
                printf("Wrong at %lu\n", i);
            }
        }
    #endif

        // arrays allocated with new[] must be released with delete[]
        delete[] a;
        delete[] b;

        return 0;
    }

And for trying this at home, a zipfile for you: radix.zip