Tuesday, March 16, 2010

Variants of CompareAndSwap  

I ran across an interesting issue the other day while trying to make a cross-platform atomics header work on Mac OS X.

Compare-and-Swap (hereafter, CAS) is one of the fundamental atomic primitives: given a working CAS you can implement any other atomic operation.
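To make that claim concrete, here is a minimal sketch of an atomic "store the maximum" operation built from nothing but a CAS retry loop. It assumes GCC or Clang, whose legacy `__sync_bool_compare_and_swap` builtin stands in for the CAS. The name `AtomicFetchMax` is hypothetical.

```c
#include <stdbool.h>

/* Hypothetical helper: atomically sets *ptr = max(*ptr, value) and returns
   the value *ptr held just before this call took effect.
   Assumes GCC/Clang for __sync_bool_compare_and_swap. */
int AtomicFetchMax(volatile int *ptr, int value) {
    for (;;) {
        int seen = *ptr;                 /* snapshot the current value */
        if (seen >= value)
            return seen;                 /* already large enough: nothing to store */
        if (__sync_bool_compare_and_swap(ptr, seen, value))
            return seen;                 /* swap happened, so seen was the old value */
        /* *ptr changed between our read and the CAS: take a new snapshot and retry */
    }
}
```

The same read / compute / CAS / retry pattern works for any read-modify-write operation you can compute from the old value.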

But there are in fact two common variants of CAS:

    // Returns the old value from *ptr
    int ValueCAS(volatile int *ptr, int compare, int swap);

    // Returns true if a swap occurred
    bool BoolCAS(volatile int *ptr, int compare, int swap);
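For reference, GCC and Clang happen to expose both variants as legacy builtins, `__sync_val_compare_and_swap` and `__sync_bool_compare_and_swap`. Here is a minimal sketch that wires the two prototypes above to them, assuming one of those compilers:

```c
#include <stdbool.h>

/* ValueCAS: returns whatever *ptr held when the CAS executed.
   Assumes GCC/Clang for the __sync builtin. */
int ValueCAS(volatile int *ptr, int compare, int swap) {
    return __sync_val_compare_and_swap(ptr, compare, swap);
}

/* BoolCAS: returns true if and only if the swap was performed. */
bool BoolCAS(volatile int *ptr, int compare, int swap) {
    return __sync_bool_compare_and_swap(ptr, compare, swap);
}
```

On success the two carry the same information: ValueCAS returns compare, and BoolCAS returns true. They differ only on failure. In that case ValueCAS also tells you what *ptr actually held, while BoolCAS discards it.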

That's OK, though, because it's theoretically possible to convert either one into the other.

It's trivial to implement BoolCAS in terms of ValueCAS. Here's how you do it:

    // GOOD: this works correctly.
    bool BoolCAS(volatile int *ptr, int compare, int swap) {
        return compare == ValueCAS(ptr,compare,swap);
    }

But for what I was doing, I needed to go the other way around. Mac OS X provides BoolCAS functions in <libkern/OSAtomic.h>, and I needed to convert those to ValueCAS functions.
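Getting from the Mac OS X function to this post's BoolCAS signature is mostly a matter of reordering arguments. The sketch below assumes the prototype `bool OSAtomicCompareAndSwapInt(int oldValue, int newValue, volatile int *theValue)`. Its non-Apple branch falls back to the GCC/Clang builtin only so that the sketch also builds elsewhere. Recent SDKs deprecate the OSAtomic functions, so expect a warning on newer macOS.

```c
#include <stdbool.h>

#if defined(__APPLE__)
#include <libkern/OSAtomic.h>
#endif

/* BoolCAS with this post's argument order: (ptr, compare, swap). */
bool BoolCAS(volatile int *ptr, int compare, int swap) {
#if defined(__APPLE__)
    /* OSAtomicCompareAndSwapInt takes (old, new, ptr): note the different order. */
    return OSAtomicCompareAndSwapInt(compare, swap, ptr);
#else
    /* Fallback (assumption: GCC or Clang) so the sketch builds off Apple platforms. */
    return __sync_bool_compare_and_swap(ptr, compare, swap);
#endif
}
```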

It turns out that this direction is subtly difficult, and I thought I'd share why.

First attempt

My first attempt was fairly simple, using just an extra load and branch:

    // BAD: do not use! Broken!
    int ValueCAS(volatile int *ptr, int compare, int swap) {
        if (BoolCAS(ptr,compare,swap))
            return compare;
        else
            return *ptr;
    }

The idea is that if the CAS worked, then the old value of *ptr must have been the same as compare, so we can just return that. That's absolutely correct.

However, if the CAS failed, what was the previous value? I was naïvely re-reading *ptr. But my unit tests (oh yes, this is the kind of thing you unit-test thoroughly) showed that very rarely, roughly once every several million iterations, it returned the wrong answer.

What was going wrong?

The problem comes down to how callers detect success and failure from ValueCAS. If ValueCAS returns compare, the CAS is assumed to have succeeded. If ValueCAS returns any other value, the CAS is assumed to have failed.
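A typical caller shows the convention in action. The following is a minimal fetch-and-add sketch. The name `FetchAdd` is hypothetical, and ValueCAS is backed here by the GCC/Clang `__sync_val_compare_and_swap` builtin purely for illustration. The caller treats a return value equal to compare as proof that its update landed. If ValueCAS ever returned compare without swapping, the increment would be silently lost.

```c
/* Stand-in for any correct ValueCAS (assumes GCC/Clang). */
static int ValueCAS(volatile int *ptr, int compare, int swap) {
    return __sync_val_compare_and_swap(ptr, compare, swap);
}

/* Hypothetical caller: atomically adds delta to *ptr, returns the previous value. */
int FetchAdd(volatile int *ptr, int delta) {
    int expected = *ptr;
    for (;;) {
        int prev = ValueCAS(ptr, expected, expected + delta);
        if (prev == expected)
            return prev;      /* the caller's rule: got compare back => swap happened */
        expected = prev;      /* otherwise retry from the value ValueCAS reported */
    }
}
```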

Once we recognize this, we can see that the naïve implementation above has not one, but two failure cases:

  1. Spurious CAS failure. On some systems, notably those where CAS is built from load-linked/store-conditional instructions (PowerPC and ARM, for example), a CAS can fail even though the comparison would have succeeded. I'd rather not get too deep into explanations of CPU behavior here, but the typical example is an interrupt arriving at exactly the wrong moment. So even if *ptr == compare when the CAS runs, a spurious failure might exit the CAS without changing *ptr. The reload then reads compare, and ValueCAS incorrectly reports success!
  2. Race between CAS and reload. If more than one thread is atomically modifying *ptr, then in the window between the failed CAS and the reload, *ptr could be changed to anything at all, including being set back to compare. Again, the reload returns compare, and ValueCAS incorrectly reports success!
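The first failure case can be reproduced deterministically without waiting for real hardware to misbehave. This sketch uses a hypothetical test double, `FlakyBoolCAS`. When the hypothetical flag `inject_spurious_failure` is set, its next call fails without touching *ptr; otherwise it defers to the GCC/Clang builtin. The first attempt is wired to this double:

```c
#include <stdbool.h>

/* Hypothetical test double: when set, the next FlakyBoolCAS call fails
   "spuriously", leaving *ptr untouched. */
static bool inject_spurious_failure = false;

static bool FlakyBoolCAS(volatile int *ptr, int compare, int swap) {
    if (inject_spurious_failure) {
        inject_spurious_failure = false;
        return false;                    /* simulated spurious failure */
    }
    return __sync_bool_compare_and_swap(ptr, compare, swap);  /* assumes GCC/Clang */
}

/* The broken first attempt, unchanged apart from using the test double. */
int NaiveValueCAS(volatile int *ptr, int compare, int swap) {
    if (FlakyBoolCAS(ptr, compare, swap))
        return compare;
    else
        return *ptr;                     /* re-read: may still equal compare */
}
```

Start with x == 5, set `inject_spurious_failure`, and call `NaiveValueCAS(&x, 5, 7)`. The call returns 5, which any caller reads as success, yet x is still 5.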

Another approach

One tempting way to fix it is to simply lie about the old value of *ptr. I didn't actually try this, but I'll admit I considered it for a moment:

    // BAD: do not use! Semantically incorrect!
    int ValueCAS(volatile int *ptr, int compare, int swap) {
        if (BoolCAS(ptr,compare,swap))
            return compare;
        else
            return (compare-1); // or any value other than compare
    }

But this is wrong! It breaks the semantics of CAS, which is supposed to return the old value. The value this function returns is bogus, and may never have existed at *ptr. To see why that matters, suppose the caller uses the value as a bitmask. We might wind up returning an illegal mask with bits that were never set, which could have unintended, ugly, subtle side-effects.
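A small sketch makes the bitmask problem concrete. It assumes GCC/Clang for the BoolCAS builtin, and `PreviouslySetBits` is a hypothetical caller that tries to move a flags word from exactly 0x8 to 0x9 and reports which bits were set beforehand:

```c
#include <stdbool.h>

/* The "lying" conversion from above (BoolCAS via the GCC/Clang builtin). */
int LyingValueCAS(volatile int *ptr, int compare, int swap) {
    if (__sync_bool_compare_and_swap(ptr, compare, swap))
        return compare;
    return compare - 1;                  /* bogus "old value" */
}

/* Hypothetical caller: returns the bits it believes were set before its CAS. */
int PreviouslySetBits(volatile int *flags) {
    return LyingValueCAS(flags, 0x8, 0x9);
}
```

Start with *flags == 0x1. The CAS fails, and PreviouslySetBits returns 0x7. That claims bits 1 and 2 were set, but they never were.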

Third time's the charm

As far as I can tell, the best way to fix it is by adding both an extra load and two branches:

    // GOOD: works correctly
    int ValueCAS(volatile int *ptr, int compare, int swap) {
        int old;
        do {
            if (BoolCAS(ptr,compare,swap))
                return compare;
            old = *ptr;
        } while (old == compare); // never return compare if CAS failed
        return old;
    }
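To check that the loop survives the first failure case, here is the corrected conversion wired to the same kind of hypothetical test double. Setting `inject_spurious_failure` forces one spurious failure; otherwise the GCC/Clang builtin does the work.

```c
#include <stdbool.h>

/* Hypothetical test double: when set, the next BoolCAS call fails without touching *ptr. */
static bool inject_spurious_failure = false;

static bool BoolCAS(volatile int *ptr, int compare, int swap) {
    if (inject_spurious_failure) {
        inject_spurious_failure = false;
        return false;
    }
    return __sync_bool_compare_and_swap(ptr, compare, swap);  /* assumes GCC/Clang */
}

/* The corrected conversion from above. */
int ValueCAS(volatile int *ptr, int compare, int swap) {
    int old;
    do {
        if (BoolCAS(ptr, compare, swap))
            return compare;
        old = *ptr;
    } while (old == compare);   /* reload saw compare: ambiguous, so retry the CAS */
    return old;
}
```

Start with x == 5 and inject one failure. `ValueCAS(&x, 5, 7)` then retries, performs the swap, and returns 5. A later `ValueCAS(&x, 5, 9)` fails genuinely and returns 7.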

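And sure enough, it's a little heavy-handed, but it works. If the reload sees compare, we can't distinguish a genuine match from a spurious failure or an intervening change, so we simply retry the CAS. Any value we do return differs from compare and is one that *ptr actually held after a failed CAS. Proper branch hinting will help, but I'm still a bit surprised that it's so expensive (relatively speaking) to do this conversion.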
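One way to express that branch hinting is sketched below. It assumes GCC or Clang and their `__builtin_expect`, and the `LIKELY`/`UNLIKELY` macros are hypothetical helpers. Marking the CAS as likely to succeed assumes low contention. The retry branch should be rare in any case.

```c
#include <stdbool.h>

/* Hypothetical hint macros over GCC/Clang's __builtin_expect. */
#define LIKELY(x)   __builtin_expect(!!(x), 1)
#define UNLIKELY(x) __builtin_expect(!!(x), 0)

/* BoolCAS via the GCC/Clang builtin (assumption). */
static bool BoolCAS(volatile int *ptr, int compare, int swap) {
    return __sync_bool_compare_and_swap(ptr, compare, swap);
}

int ValueCAS(volatile int *ptr, int compare, int swap) {
    int old;
    do {
        if (LIKELY(BoolCAS(ptr, compare, swap)))   /* hint: usually uncontended */
            return compare;
        old = *ptr;
    } while (UNLIKELY(old == compare));            /* hint: retry is rare */
    return old;
}
```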

Can anyone come up with a better way?