How We Improved the Concurrency and Scalability of Our Redis Rate Limiting System

September 21, 2021

Background

Rate limiting is a technique used to protect services from overload. In addition, it can be used to prevent starvation of a multi-tenant resource by a few very large customers. At Rockset, we primarily use rate limiting to protect our:

  1. metadata store from overload caused by too many API requests.
  2. log store from filling up due to mismatched input and output rates.
  3. control plane from too many state transitions.

We use Redisson's RRateLimiter, which uses Redis under the hood to track rate usage. At a very basic level, our usage of the library looks like this (specific business logic omitted for readability):

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    return rateLimitService.acquire(key, requestedTokens);
  }
}

We won't dive into the details of RRateLimiter; suffice it to say that this makes a network call to Redis. rateLimitService.acquire returns true if granting requestedTokens would not exceed the rate limit and false otherwise.
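
To make the acquire contract concrete, here is a minimal in-memory sketch of a limiter with the same true/false behavior. It is only an illustration: the class name InMemoryRateLimitService, the fixed-window counting, and the explicit nowSecs parameter are all assumptions made for this example, and none of it reflects how Redisson or Redis actually track usage.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the Redis-backed service (not Redisson's API).
// Assumes a simple fixed-window counter per key.
class InMemoryRateLimitService {
  private final long limitPerWindow;
  private final long windowSecs;
  // key -> {windowStartSecs, tokensUsedInWindow}
  private final Map<String, long[]> usage = new HashMap<>();

  InMemoryRateLimitService(long limitPerWindow, long windowSecs) {
    this.limitPerWindow = limitPerWindow;
    this.windowSecs = windowSecs;
  }

  // Returns true (and records the usage) if the request fits within the
  // current window's limit; returns false and records nothing otherwise.
  synchronized boolean acquire(String key, long requestedTokens, long nowSecs) {
    long[] state = usage.computeIfAbsent(key, k -> new long[] {nowSecs, 0});
    if (nowSecs - state[0] >= windowSecs) {
      // The previous window is over; start a fresh one.
      state[0] = nowSecs;
      state[1] = 0;
    }
    if (state[1] + requestedTokens > limitPerWindow) {
      return false;
    }
    state[1] += requestedTokens;
    return true;
  }
}
```

With a limit of 10 tokens per 60-second window, a request for 6 tokens succeeds, a following request for 5 is rejected, and a request for 4 still fits.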

Problem

Recently, we saw that due to many requests to Redis, the CPU on our Redis cluster was getting close to 100%. The first thing we tried was vertically scaling up our Redis instance to buy us time. However, vertical scaling has its own limits and every few weeks we would end up with another surge in Redis CPU.

We also noticed that Redisson uses Lua scripting on the server side, and that Lua compilation was taking up a decent chunk of CPU time. Another piece of low-hanging fruit was configuring Redisson to cache compiled Lua scripts on the server side, reducing the CPU time spent on this task. Since this was a simple config change, it didn’t require a code deploy and was easy to roll out.

Apart from vertical scaling and improving configuration, we brainstormed a few other approaches to the problem:

  1. We could shard Redis over the rate limit keys to spread the load and horizontally scale.
  2. We could queue rate limit requests locally and have a single thread that periodically (e.g., every 50 ms) takes n items off the queue and requests a larger batch of tokens from Redis.
  3. We could proactively reserve larger batches of tokens and cache them locally. When a request for tokens comes in, try to serve it from the local cache. If the cache doesn't hold enough tokens, go fetch a larger batch. This is analogous to malloc not making a system call every time memory is requested and instead reserving larger chunks that it manages itself.

Horizontally scaling Redis by sharding is a great long-term solution; it is probably something we’re going to end up doing at some point.

The problem with the second approach is it raises a few complexities: How frequently does the thread pull from the queue and poll? Do you cap the size of the queue and if so, what happens if the queue is full? How do you even set the cap on the queue? What if Redis has 50 tokens and we batch 10 requests each needing 10 tokens (asking Redis for a total of 100 tokens)? Ideally 5 requests should succeed, but in reality all 10 would fail. These problems are solvable, but would make the implementation quite complex. Thus, we ended up implementing the third solution.
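
The 50-token example can be worked through in a few lines. This is only a sketch of the arithmetic; the class and method names are made up for illustration, and it assumes the batched request to Redis is all-or-nothing.

```java
class BatchingExample {
  // Grants each request on its own against the remaining tokens.
  static int grantedIndividually(long available, long[] requests) {
    int granted = 0;
    for (long requested : requests) {
      if (requested <= available) {
        available -= requested;
        granted++;
      }
    }
    return granted;
  }

  // Sums all requests into one all-or-nothing call, as the queueing
  // approach would: either every request succeeds or none does.
  static int grantedAsOneBatch(long available, long[] requests) {
    long total = 0;
    for (long requested : requests) {
      total += requested;
    }
    return total <= available ? requests.length : 0;
  }
}
```

With 50 tokens available and 10 requests of 10 tokens each, grantedIndividually returns 5 while grantedAsOneBatch returns 0.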

As shown towards the end of the post, this implementation reduced the share of rate limit calls that reach Redis by 96%. The rest of this post will explore how we implemented the third approach. It goes into some of the pitfalls, complexities, and things to consider when working on a batch-oriented solution such as this one.

Implementation

Note that the code presented in this post is in Java. For simplicity, not all error handling is shown. Also, I will reference a now() method, which simply returns the Unix timestamp in seconds since the epoch.
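
One possible way to write such a helper is sketched below. The TimeUtil wrapper class is an assumption for this example; the snippets in this post call now() as if it were a method on the rate limiter itself.

```java
import java.time.Instant;

class TimeUtil {
  // Seconds elapsed since the Unix epoch (1970-01-01T00:00:00Z).
  static long now() {
    return Instant.now().getEpochSecond();
  }
}
```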

Let's start simple:

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;
  private final long batchSize = ...;
  private final long timeWindowSecs = ...;
  private long reservedTokens = 0;
  private long expirationTs = 0;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    // In this case, we might as well make a direct call to
    // simplify things.
    if (requestedTokens > batchSize) {
      return rateLimitService.acquire(key, requestedTokens);
    }

    // Reserved tokens are only usable until they expire.
    if (reservedTokens >= requestedTokens && expirationTs > now()) {
      reservedTokens -= requestedTokens;
      return true;
    }

    if (rateLimitService.acquire(key, batchSize)) {
      reservedTokens = batchSize - requestedTokens;
      expirationTs = now() + timeWindowSecs;
      return true;
    }

    return false;
  }
}

This code looks fine at first glance, but what happens if multiple threads need to call isNotRateLimited at the same time? The above code is certainly not thread safe. I will leave it as an exercise for the reader to work out why making reservedTokens an atomic variable won't solve the problem (although do let us know if you come up with a clever lock-free solution). If atomics won't work, we can try using locks instead:

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;
  private final long batchSize = ...;
  private final long timeWindowSecs = ...;
  private final Lock lock = new ReentrantLock();
  private long reservedTokens = 0;
  private long expirationTs = 0;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    // In this case, we might as well make a direct call to
    // simplify things.
    if (requestedTokens > batchSize) {
      return rateLimitService.acquire(key, requestedTokens);
    }

    lock.lock();
    try {
      if (reservedTokens >= requestedTokens && expirationTs > now()) {
        reservedTokens -= requestedTokens;
        return true;
      } else if (expirationTs > now()) {
        // Use up remaining tokens
        requestedTokens -= reservedTokens;
        reservedTokens = 0;
      }
    } finally {
      // Easy to overlook; don't lock across the network request.
      lock.unlock();
    }

    if (rateLimitService.acquire(key, batchSize)) {
      lock.lock();
      try {
        reservedTokens = (batchSize - requestedTokens);
        expirationTs = now() + timeWindowSecs;
      } finally {
        lock.unlock();
      }
      return true;
    }

    return false;
  }
}

While at first glance this looks correct, there is one subtle problem with it. What happens if multiple threads see there aren't enough reservedTokens? Let's say reservedTokens is 0, our batchSize is 100, and 5 threads request 20 tokens each concurrently.

All 5 threads will see that there aren't enough reserved tokens and each will fetch 100 tokens. Now, this machine has taken 500 tokens from Redis to serve 100 tokens' worth of requests, and has made 5x too many requests to the external store. Worse, each thread overwrites reservedTokens with 80, so most of the 400 leftover tokens are never handed out. Can we do better? All we really need is for one thread to go and fetch a batch; the other 4 threads can then just use that batch. That means 1 network call and fewer wasted tokens.

With some booleans and condition variables, we can pretty easily achieve this. If you're unfamiliar with how condition variables work, check out the Java docs; most languages have some sort of condition variable implementation as well. Here's the code:

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;
  private final long batchSize = ...;
  private final long timeWindowSecs = ...;
  private final Lock lock = new ReentrantLock();
  private final Condition fetchCondition = lock.newCondition();
  private boolean fetchInProgress = false;
  private long reservedTokens = 0;
  private long expirationTs = 0;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    // In this case, we might as well make a direct call to
    // simplify things.
    if (requestedTokens > batchSize) {
      return rateLimitService.acquire(key, requestedTokens);
    }

    boolean doFetch = false;
    lock.lock();
    try {
      if (reservedTokens >= requestedTokens && expirationTs > now()) {
        reservedTokens -= requestedTokens;
        return true;
      } else if (expirationTs > now()) {
        requestedTokens -= reservedTokens;
        reservedTokens = 0;
      }

      if (fetchInProgress) {
        // Thread is already fetching; let's wait for it to finish.
        fetchCondition.await();
        if (reservedTokens >= requestedTokens) {
          reservedTokens -= requestedTokens;
          return true;
        }
        return false;
      } else {
        doFetch = true; // This thread should fetch the batch
        fetchInProgress = true; // Avoid other threads from fetching.
      }
    } finally {
      lock.unlock();
    }

    if (doFetch) {
      boolean acquired = rateLimitService.acquire(key, batchSize);
      lock.lock();
      try {
        if (acquired) {
          reservedTokens = (batchSize - requestedTokens);
          expirationTs = now() + timeWindowSecs;
        }
        fetchInProgress = false; // Allow a later caller to start the next fetch
        fetchCondition.signalAll(); // Wake up waiting threads
      } finally {
        lock.unlock();
      }
      return acquired;
    }

    return false;
  }
}
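
For readers new to condition variables, the wait/signal idiom used above can be isolated into a small sketch. The FetchGate class and its method names are made up for this example and are not part of our rate limiter. It also shows one habit worth adopting: waiting in a while loop on the guarded flag, since Condition.await() is allowed to wake up spuriously.

```java
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical helper: waiters block until one thread reports that its
// fetch has completed.
class FetchGate {
  private final Lock lock = new ReentrantLock();
  private final Condition fetchDone = lock.newCondition();
  private boolean fetchInProgress = true;
  private long fetchedTokens = 0;

  // Called by waiting threads. Blocks until the fetch completes, then
  // returns the number of tokens it produced.
  long awaitFetch() throws InterruptedException {
    lock.lock();
    try {
      // Re-check the flag after every wakeup to guard against
      // spurious wakeups.
      while (fetchInProgress) {
        fetchDone.await(); // Releases the lock while waiting
      }
      return fetchedTokens;
    } finally {
      lock.unlock();
    }
  }

  // Called by the fetching thread once its network call returns.
  void completeFetch(long tokens) {
    lock.lock();
    try {
      fetchedTokens = tokens;
      fetchInProgress = false;
      fetchDone.signalAll(); // Wake up every waiting thread
    } finally {
      lock.unlock();
    }
  }
}
```

Because the waiter checks the flag before it ever calls await(), it does not matter whether completeFetch runs before or after the waiter arrives.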

Now, we will only ever have one thread at a time fetching a batch. While the code is logically correct, we might end up rate limiting a thread too aggressively:

Let's say our batch size is 100 and we have 5 threads requesting 25 tokens each concurrently. The first thread (call it T1) will fetch the batch from the external service. The other 4 threads will wait on the condition variable. However, the 5th thread will have waited for no reason because the first 4 threads will use up all the tokens in the fetched batch. Instead, it might have been better to either:

  1. Immediately return false for the 5th thread (this will rate limit too aggressively)
  2. Or have the 5th thread make a direct call to the external service, not waiting on the first thread.

The second solution is implemented below:

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;
  private final long batchSize = ...;
  private final long timeWindowSecs = ...;
  private final Lock lock = new ReentrantLock();
  private final Condition fetchCondition = lock.newCondition();
  private boolean fetchInProgress = false;
  private long reservedTokens = 0;
  private long expirationTs = 0;
  // Number of tokens that waiting threads will use up.
  private long unreservedFetchTokens = 0;
  // Used by waiting threads to determine if the fetch they are
  // waiting for succeeded or not.
  private boolean didFetchSucceed = false;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    // In this case, we might as well make a direct call to
    // simplify things.
    if (requestedTokens > batchSize) {
      return rateLimitService.acquire(key, requestedTokens);
    }

    boolean doFetch = false;
    lock.lock();
    try {
      if (reservedTokens >= requestedTokens && expirationTs > now()) {
        reservedTokens -= requestedTokens;
        return true;
      } else if (expirationTs > now()) {
        requestedTokens -= reservedTokens;
        reservedTokens = 0;
      }

      if (fetchInProgress) {
        if (unreservedFetchTokens >= requestedTokens) {
          // Reserve your spot in line
          unreservedFetchTokens -= requestedTokens;
          fetchCondition.await();
          // If we get here and the fetch succeeded, then we
          // are fine.
          return didFetchSucceed;
        }
      } else {
        doFetch = true;
        fetchInProgress = true;
        unreservedFetchTokens = batchSize - requestedTokens;
      }
    } finally {
      lock.unlock();
    }

    if (doFetch) {
      boolean acquired = rateLimitService.acquire(key, batchSize);
      lock.lock();
      try {
        didFetchSucceed = acquired;
        if (acquired) {
          reservedTokens = unreservedFetchTokens;
          expirationTs = now() + timeWindowSecs;
        }
        fetchInProgress = false; // Allow a later caller to start the next fetch
        fetchCondition.signalAll(); // Wake up waiting threads
      } finally {
        lock.unlock();
      }
      return acquired;
    }

    // If we get here, it means there weren't enough
    // unreservedFetchTokens. Let's just make our own
    // call rather than waiting in line.
    return rateLimitService.acquire(key, requestedTokens);
  }
}

Finally, we've arrived at an acceptable solution. In practice, the lock contention should be minimal as we are only setting a few primitive values. But, as with anything, you should benchmark this solution for your use case and see if it makes sense.
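
As a sanity check on the bookkeeping, the five-thread scenario from earlier can be replayed sequentially. The sketch below is a simplified, single-threaded model of the decisions made under the lock, not the limiter itself. The class name and the FETCH/WAIT/DIRECT labels are made up for illustration, and it assumes every request arrives while the first fetch is still in flight and is no larger than the batch size.

```java
import java.util.ArrayList;
import java.util.List;

class ReservationWalkthrough {
  // Returns one decision per request: "FETCH" for the thread that fetches
  // the batch, "WAIT" for threads that reserve a share of that batch, and
  // "DIRECT" for threads that call the external service themselves.
  static List<String> decide(long batchSize, long[] requests) {
    List<String> decisions = new ArrayList<>();
    boolean fetchInProgress = false;
    long unreservedFetchTokens = 0;
    for (long requested : requests) {
      if (!fetchInProgress) {
        fetchInProgress = true;
        unreservedFetchTokens = batchSize - requested;
        decisions.add("FETCH");
      } else if (unreservedFetchTokens >= requested) {
        unreservedFetchTokens -= requested; // Reserve a spot in line
        decisions.add("WAIT");
      } else {
        decisions.add("DIRECT"); // Not enough left in the pending batch
      }
    }
    return decisions;
  }
}
```

For a batch size of 100 and five requests of 25 tokens each, the first thread fetches, the next three wait for that batch, and the fifth makes its own call instead of waiting.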

Setting the batch size

One remaining question is how to set batchSize. There is a tradeoff here: If batchSize is too low, the number of requests to Redis will approach the number of requests to isNotRateLimited. If batchSize is too high, hosts will reserve too many tokens, starving out other hosts. One thing to consider is whether these hosts can be auto scaled. If so, once numHosts * batchSize exceeds the rate limit, other hosts will start getting starved out even if the number of requests is under the rate limit.
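
The starvation condition is easy to check up front. The helper below is a small sketch with made-up names; it assumes every host can hold at most one full batch at a time.

```java
class BatchSizeCheck {
  // True once the hosts can collectively hold more tokens than the rate
  // limit allows, which is the point at which some hosts can be starved.
  static boolean canStarveOtherHosts(long rateLimit, long numHosts, long batchSize) {
    return numHosts * batchSize > rateLimit;
  }

  // Largest batch size that keeps numHosts * batchSize within the limit.
  static long maxSafeBatchSize(long rateLimit, long numHosts) {
    return rateLimit / numHosts;
  }
}
```

For example, with a rate limit of 1000 and a batch size of 100, 10 hosts are fine, but auto scaling to 11 hosts crosses the threshold unless the batch size drops to 90 or below.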

To address some of this, it would be fascinating to explore using a dynamically set batch size. If this machine used up the entire last batch, maybe it can request 1.5x the batch next time (with a cap of course). Alternatively, if batches are going to waste, perhaps only ask for half the batch next time.
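
A rough sketch of what such a policy might look like is below. We have not implemented this; the class name, the min/max bounds, and the rule that any leftover tokens count as waste are all assumptions chosen to keep the example short.

```java
// Hypothetical adaptive policy: grow the batch by 1.5x when the previous
// batch was fully used, halve it when tokens were left over.
class AdaptiveBatchSizer {
  private final long minBatch;
  private final long maxBatch;
  private long current;

  AdaptiveBatchSizer(long initial, long minBatch, long maxBatch) {
    this.current = initial;
    this.minBatch = minBatch;
    this.maxBatch = maxBatch;
  }

  // Call when a batch is exhausted or expires. unusedTokens is how many
  // tokens of the previous batch were never handed out. Returns the size
  // to request next.
  long next(long unusedTokens) {
    if (unusedTokens == 0) {
      current = Math.min(maxBatch, current + current / 2); // 1.5x, capped
    } else {
      current = Math.max(minBatch, current / 2); // Halve, with a floor
    }
    return current;
  }
}
```

Starting from 100 with bounds of 10 and 200, two fully used batches grow the size to 150 and then 200 (capped), and a wasted batch after that drops it back to 100.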

Results

As an initial starting point, we set the batchSize to be 1/1000 of the rate limit for a given resource. For our workload, this resulted in ~4% of rate limit requests going to Redis, a massive improvement. This can be seen in the chart below, where the x-axis is time and the y-axis is percent of requests hitting Redis:

[Figure 1: percent of rate limit requests hitting Redis over time]

Improving our rate limiting at Rockset is an ongoing process and this probably won’t be the last improvement we need to make in this area. Stay tuned for more. And if you’re interested in solving these types of problems, we are hiring!

A quick aside

As an aside, the following code has a very subtle concurrency bug. Can you spot it?

class RedisRateLimiter {
  private final RRateLimiter rateLimitService = ...;
  private final long batchSize = ...;
  private final long timeWindowSecs = ...;
  private final Lock lock = new ReentrantLock();
  private final Condition fetchCondition = lock.newCondition();
  private boolean fetchInProgress = false;
  private long reservedTokens = 0;
  private long expirationTs = 0;
  // Number of tokens that waiting threads will use up.
  private long unreservedFetchTokens = 0;

  public boolean isNotRateLimited(String key, int requestedTokens) {
    // In this case, we might as well make a direct call to
    // simplify things.
    if (requestedTokens > batchSize) {
      return rateLimitService.acquire(key, requestedTokens);
    }

    boolean doFetch = false;
    lock.lock();
    try {
      if (reservedTokens >= requestedTokens) {
        reservedTokens -= requestedTokens;
        return true;
      } else if (expirationTs > now()) {
        requestedTokens -= reservedTokens;
        reservedTokens = 0;
      }

      if (fetchInProgress) {
        if (unreservedFetchTokens >= requestedTokens) {
          // Reserve your spot in line
          unreservedFetchTokens -= requestedTokens;
          fetchCondition.await();
          if (reservedTokens >= requestedTokens) {
            reservedTokens -= requestedTokens;
            return true;
          }
          return false;
        }
      } else {
        doFetch = true;
        fetchInProgress = true;
        unreservedFetchTokens = batchSize - requestedTokens;
      }
    } finally {
      lock.unlock();
    }

    if (doFetch) {
      boolean acquired = rateLimitService.acquire(key, batchSize);
      lock.lock();
      if (acquired) {
        reservedTokens = (batchSize - requestedTokens);
        expirationTs = now() + timeWindowSecs;
      }
      fetchInProgress = false; // Allow a later caller to start the next fetch
      fetchCondition.signalAll(); // Wake up waiting threads
      lock.unlock();
      return acquired;
    }

    // If we get here, it means there weren't enough
    // unreservedFetchTokens. Let's just make our own
    // call rather than waiting in line.
    return rateLimitService.acquire(key, requestedTokens);
  }
}

Hint: Even if rateLimitService.acquire always returned true, you can end up in situations where isNotRateLimited returns false.