Writing Redis in Python with asyncio: Part 1

Mon 23 March 2015 by James Saryerwinnie

Python 3.4 featured a brand new library that's been getting a lot of attention: asyncio. For numerous reasons, including the fact that the originator of the PEP is Guido himself, the asyncio library is growing in popularity within the python community.

So, I'm thinking that it might be fun to try to use this new asyncio library to write redis in pure python.

The Pitch

I maintain a library, fakeredis, which is a testing library that emulates redis via the redis-py client API. It doesn't have any of the runtime guarantees that redis has (yet), but for the most part, it has the same functional behavior as redis. All of its state is kept in memory, and it does nothing for persisting state to disk. After all, it's meant as a testing library, to avoid having to spin up a real redis server during your python unit tests. Here's fakeredis in a nutshell:

>>> import fakeredis
>>> r = fakeredis.FakeStrictRedis()
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
'bar'
>>> r.lpush('bar', 1)
1
>>> r.lpush('bar', 2)
2
>>> r.lrange('bar', 0, -1)
[2, 1]

Same semantics as the redis-py client, with all state stored in the memory of the process.

So the idea is to use asyncio to build a server that accepts client connections and parses the redis wire protocol. The server then figures out the corresponding calls to make into fakeredis, which provides all the functional redis semantics, takes the return values from fakeredis, and constructs the appropriate wire response.

If all goes well, I should have something that any redis client can talk to without knowing it's not actually talking to the real redis server. We will have created a slower, less memory efficient implementation of the redis server without "required" features like "persistence" or "replication". It's redis, written in python, using asyncio. Sounds like fun.

If nothing else, we'll learn a little more about asyncio in the process.

Setting Scope

Now first off, I plan for this to be a multipart series.

The scope for this post, part 1, is to get to the point where we can make redis calls for all its basic functionality, which includes the API calls for manipulating data for all of redis's supported types. Perhaps what's more interesting is what I'm leaving out in this post.

What I won't look at in this post is:

  • saving to disk
  • blocking operations, such as BLPOP
  • performance
  • handling slow clients
  • expirations
  • any kind of replication
  • testing

These items will be the subject of future posts. This is a long-winded way of me saying that we're going to be taking shortcuts. It'll be ok.

Assumptions

To get the most out of this post, I'm assuming that:

  • You're familiar with redis from an end-user perspective. You know what redis is and you're familiar with the basic commands.
  • You're new to asyncio, but you're not necessarily new to event driven programming.
  • You're using python 3.4 or greater.

Get the Skeleton Up and Running

The very first thing I want to do is get something up and running. It doesn't have to do much, but I want to be able to at least have the server handle a request and return a response, even if it's hardcoded. I'm going to pick the GET command because it's the simplest operation that provides useful functionality. Once we get this running, we'll pick it apart and figure out how it actually works.

So first things first, let's hop over to the asyncio reference docs.

End to End Skeleton

Asyncio appears to have a huge amount of documentation, but most of it is stuff I don't care about right now. The closest thing that looks interesting is the TCP echo server protocol example, which shows a basic echo server with asyncio. We should be able to start with the echo server and adapt it to what we want, at least initially. Here's what I came up with after adapting the echo server example to a hardcoded redis GET command.

#!/usr/bin/env python3
import asyncio


class RedisServerProtocol(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        message = data.decode()
        if 'GET' in message:
            self.transport.write(b"$3\r\n")
            self.transport.write(b"BAZ\r\n")
        else:
            self.transport.write(b"-ERR unknown command\r\n")


def main(hostname='localhost', port=6379):
    loop = asyncio.get_event_loop()
    coro = loop.create_server(RedisServerProtocol,
                              hostname, port)
    server = loop.run_until_complete(coro)
    print("Listening on port {}".format(port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("User requested shutdown.")
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
        print("Redis is now ready to exit.")
    return 0


if __name__ == '__main__':
    main()

Save the above code to a file named redis-asyncio, make it executable with chmod +x redis-asyncio, and run it. We'll use redis-cli to verify that it behaves the way we want:

$ ./redis-asyncio &
[1] 96221
Listening on port 6379

$ redis-cli
127.0.0.1:6379> GET foo
"BAZ"
127.0.0.1:6379> GET bar
"BAZ"
127.0.0.1:6379> GET anything
"BAZ"
127.0.0.1:6379> FOOBAR asdf
(error) ERR unknown command

It works!
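If you don't have redis-cli handy, a rough equivalent check can be done from Python. This is a sketch, not part of the server itself: it re-declares the hardcoded protocol so it runs on its own, starts it on port 0 (the OS picks a free port, so a real redis on 6379 is left alone), and sends RESP-encoded versions of GET foo and FOOBAR asdf. It uses async/await and asyncio.run, so it assumes Python 3.7 or later, newer than the 3.4 baseline this post targets. send_raw and check are hypothetical helper names.

```python
import asyncio


class RedisServerProtocol(asyncio.Protocol):
    # Same hardcoded skeleton as above: any message containing GET gets "BAZ".
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        if 'GET' in data.decode():
            self.transport.write(b"$3\r\nBAZ\r\n")
        else:
            self.transport.write(b"-ERR unknown command\r\n")


async def send_raw(port, payload):
    """Send raw RESP bytes, half-close the socket, and read the reply until EOF."""
    reader, writer = await asyncio.open_connection('127.0.0.1', port)
    writer.write(payload)
    # asyncio.Protocol's default eof_received() returns None, so the server
    # transport closes itself (after flushing its reply) once it sees our EOF.
    writer.write_eof()
    reply = await reader.read()
    writer.close()
    await writer.wait_closed()
    return reply


async def check():
    loop = asyncio.get_running_loop()
    # Port 0 asks the OS for any free port.
    server = await loop.create_server(RedisServerProtocol, '127.0.0.1', 0)
    port = server.sockets[0].getsockname()[1]
    try:
        # RESP arrays of bulk strings for "GET foo" and "FOOBAR asdf".
        get_reply = await send_raw(port, b"*2\r\n$3\r\nGET\r\n$3\r\nfoo\r\n")
        bad_reply = await send_raw(port, b"*2\r\n$6\r\nFOOBAR\r\n$4\r\nasdf\r\n")
    finally:
        server.close()
        await server.wait_closed()
    return get_reply, bad_reply


if __name__ == '__main__':
    print(asyncio.run(check()))
```

Running it should print the two raw replies: the bulk string for BAZ and the error line, which is exactly what redis-cli rendered as "BAZ" and (error) ERR unknown command.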

But How Does it Work?

There's a lot we haven't explained yet.

While I'm going to skip over get_event_loop and run_until_complete for now, the create_server call is interesting. How exactly does the server we create integrate with the RedisServerProtocol we wrote? For example, how do we go from create_server to calling RedisServerProtocol.connection_made?

What helped me the most was digging into the asyncio source code, so let's do that. I've annotated and simplified the code to give you a high level view of what's going on. We'll start with create_server and keep going through the various methods until we see our protocol's connection_made method being called.

# Simplified excerpts from asyncio: event loop methods, followed by
# the transport class that the last of them creates.

@coroutine
def create_server(self, protocol_factory, host=None, port=None,
                  *,
                  family=socket.AF_UNSPEC,
                  flags=socket.AI_PASSIVE,
                  sock=None,
                  backlog=100,
                  ssl=None,
                  reuse_address=None):
    # In this scenario the ``protocol_factory`` maps
    # to the ``RedisServerProtocol`` class object.

    # Create listening socket(s).
    socket = lots_of_code()

    server = Server(self, socket)
    # Once we create a server, we call _start_serving.
    # Note how we're passing along the protocol_factory
    # argument (our ``RedisServerProtocol`` class).
    self._start_serving(protocol_factory, socket, ssl, server)
    return server

def _start_serving(self, protocol_factory, sock,
                   sslcontext=None, server=None):
    # We're registering the _accept_connection method to be called
    # when a new connection is made.  Again notice how we're
    # still passing along our protocol_factory (``RedisServerProtocol``
    # class) object.
    self.add_reader(sock.fileno(), self._accept_connection,
                    protocol_factory, sock, sslcontext, server)

def _accept_connection(self, protocol_factory, sock,
                       sslcontext=None, server=None):
    # Accept the pending client connection on the listening socket.
    conn, addr = sock.accept()
    conn.setblocking(False)
    # Finally!  Here we can see that we instantiate
    # the protocol_factory class to actually get
    # an instance of ``RedisServerProtocol``.
    # We've gone from a class to an instance.  So
    # what about connection_made?  When does this get
    # called?  Down the stack to _make_socket_transport.
    self._make_socket_transport(
        conn, protocol_factory(), extra={'peername': addr},
        server=server)

def _make_socket_transport(self, sock, protocol, waiter=None, *,
                           extra=None, server=None):
    # At least we can see we've gone from protocol_factory to
    # just protocol, so now "protocol" in this scenario is an
    # instance of ``RedisServerProtocol``.
    return _SelectorSocketTransport(self, sock, protocol, waiter,
                                    extra, server)

class _SelectorSocketTransport(_SelectorTransport):
    def __init__(self, loop, sock, protocol, waiter=None,
                 extra=None, server=None):
        super().__init__(loop, sock, protocol, extra, server)
        self._eof = False
        self._paused = False

        self._loop.add_reader(self._sock_fd, self._read_ready)
        # And finally, we see that we ask the event loop to call
        # the connection_made method of our protocol class, and we're
        # passing "self" (The transport object) as an argument to
        # connection_made.
        self._loop.call_soon(self._protocol.connection_made, self)

Recap

So far, we've learned:

  • It looks like the interesting stuff we'll be writing is in the Protocol. To write our own redis server, we're going to flesh out a proper RedisServerProtocol class that understands the redis wire protocol.
  • We get one protocol instance per client connection, so any state stored on the protocol is scoped to the lifetime of that connection.
  • To wire things up, hand the protocol class to create_server, which is called on an event loop instance. As we saw in _accept_connection() above, the protocol_factory argument is called with no args to create a protocol instance. While a class object works fine for now, we're going to have to use a closure or a factory class to pass arguments to the protocol when it's created.
  • Protocols let you define methods that the event loop invokes: asyncio calls connection_made() when a client connects and data_received() when bytes arrive. Looking at the Protocol classes, there appear to be a few more methods you can implement, such as connection_lost() and eof_received().
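To make the recap concrete, here's a minimal sketch of a protocol that keeps both shared and per-connection state, plus a closure-based factory so create_server can still call it with no arguments. CountingProtocol, make_factory, and registry are hypothetical names for illustration, not part of asyncio or this project. It also implements connection_lost(), one of the extra Protocol methods mentioned above, whose argument is None on a regular EOF or close and an exception object otherwise.

```python
import asyncio


class CountingProtocol(asyncio.Protocol):
    """Hypothetical protocol showing per-connection vs. shared state."""

    def __init__(self, registry):
        self.registry = registry   # shared: the same list for every connection
        self.chunks_seen = 0       # per-connection: lives as long as this instance
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport
        self.registry.append(self)

    def data_received(self, data):
        # Counts calls, not commands: TCP may split or merge messages.
        self.chunks_seen += 1

    def connection_lost(self, exc):
        # exc is None on a clean EOF/close, else the error that ended it.
        self.registry.remove(self)


def make_factory(registry):
    """Return a zero-argument callable suitable for loop.create_server()."""
    def factory():
        return CountingProtocol(registry)
    return factory
```

You would pass make_factory(active_connections) to loop.create_server in place of the bare class. Each accepted connection gets a fresh CountingProtocol, while every instance shares the same registry list.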

Now that we understand the basics, we can start looking at the redis wire protocol.

Parsing the Wire Protocol

The first thing we need in order to properly handle requests is a protocol parser: the code that takes the redis request off the TCP socket and parses it into something meaningful. The code for this isn't that interesting; after reading the docs for the redis wire protocol, it's straightforward to implement.

Now, none of this is optimized yet, but here's a basic implementation of parsing the redis wire protocol. It accepts a byte string, and returns python objects.

import io


def parse_wire_protocol(message):
    return _parse_wire_protocol(io.BytesIO(message))


def _parse_wire_protocol(msg_buffer):
    current_line = msg_buffer.readline()
    msg_type, remaining = chr(current_line[0]), current_line[1:]
    if msg_type == '+':
        return remaining.rstrip(b'\r\n').decode()
    elif msg_type == ':':
        return int(remaining)
    elif msg_type == '$':
        msg_length = int(remaining)
        if msg_length == -1:
            return None
        result = msg_buffer.read(msg_length)
        # There's a '\r\n' that comes after a bulk string
    # so we .readline() to move past that crlf.
        msg_buffer.readline()
        return result
    elif msg_type == '*':
        array_length = int(remaining)
        return [_parse_wire_protocol(msg_buffer) for _ in range(array_length)]

We're also going to need the inverse of this, something that takes a response from fakeredis and converts it back into bytes that can be sent across the wire. Again, nothing too interesting about this code, but here's what I came up with:

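def serialize_to_wire(value):
    if isinstance(value, str):
        return ('+%s' % value).encode() + b'\r\n'
    elif isinstance(value, bool):
        # Check bool before int (bool is a subclass of int). True maps
        # to +OK; False becomes :0 instead of the invalid ":False".
        return b"+OK\r\n" if value else b":0\r\n"
    elif isinstance(value, int):
        return (':%s' % value).encode() + b'\r\n'
    elif isinstance(value, bytes):
        return (b'$' + str(len(value)).encode() +
                b'\r\n' + value + b'\r\n')
    elif value is None:
        return b'$-1\r\n'
    elif isinstance(value, (list, tuple, set)):
        # SMEMBERS results come back as a python set, so sets
        # are serialized as arrays too.
        base = b'*' + str(len(value)).encode() + b'\r\n'
        for item in value:
            base += serialize_to_wire(item)
        return base
    else:
        raise TypeError("Can't serialize %r to the wire" % (value,))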

Let's try this out:

>>> set_request = b'*3\r\n$3\r\nset\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'
>>> parse_wire_protocol(set_request)
[b'set', b'foo', b'bar']

>>> serialize_to_wire([b'5', b'4', b'3', b'2', b'1'])
b'*5\r\n$1\r\n5\r\n$1\r\n4\r\n$1\r\n3\r\n$1\r\n2\r\n$1\r\n1\r\n'

After calling parse_wire_protocol, we can see that we get a list of [command_name, arg1, arg2, ...].
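Clients send every command as an array of bulk strings, which is why the parsed result is always a list of bytes. To generate such requests in a REPL without a real client, a small helper can build them. encode_command is a hypothetical name; it's a sketch that accepts str, bytes, or int arguments, and its output for set foo bar should match the set_request bytes from the example above.

```python
def encode_command(*args):
    """Encode a command the way redis clients send it: an array of bulk strings."""
    parts = [b'*' + str(len(args)).encode() + b'\r\n']
    for arg in args:
        if isinstance(arg, str):
            arg = arg.encode()
        elif isinstance(arg, int):
            # Integers travel as their decimal text, e.g. -1 -> b'-1'.
            arg = str(arg).encode()
        # Bulk string: $<byte length>\r\n<bytes>\r\n
        parts.append(b'$' + str(len(arg)).encode() + b'\r\n' + arg + b'\r\n')
    return b''.join(parts)
```

Because every argument is length-prefixed, values may safely contain \r\n, which is also why the parser reads bulk strings by length rather than by line.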

Implementing the Protocol Class

We should have everything we need to make a more realistic RedisServerProtocol class now. We're making the assumption for now that the entire command is provided when data_received is called.
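That assumption won't always hold: TCP is a byte stream, so a single data_received call can carry half a command, or several commands at once from a pipelining client (for example redis-benchmark with its -P option). One common approach, sketched here and not what this post's code does, is to append incoming bytes to a buffer and only hand off messages once they're complete. parse_one, Incomplete, and RequestBuffer are hypothetical names.

```python
class Incomplete(Exception):
    """Raised when the buffer doesn't yet hold a full RESP message."""


def parse_one(buf, pos=0):
    """Parse one RESP value from buf at pos; return (value, new_pos)."""
    end = buf.find(b'\r\n', pos)
    if end == -1:
        raise Incomplete()
    msg_type, line = buf[pos:pos + 1], bytes(buf[pos + 1:end])
    pos = end + 2
    if msg_type == b'+':
        return line.decode(), pos
    if msg_type == b':':
        return int(line), pos
    if msg_type == b'$':
        length = int(line)
        if length == -1:
            return None, pos
        # Need the payload plus its trailing \r\n before we can return.
        if len(buf) < pos + length + 2:
            raise Incomplete()
        return bytes(buf[pos:pos + length]), pos + length + 2
    if msg_type == b'*':
        items = []
        for _ in range(int(line)):
            item, pos = parse_one(buf, pos)
            items.append(item)
        return items, pos
    raise ValueError('unknown RESP type: %r' % msg_type)


class RequestBuffer:
    """Accumulates bytes from data_received and returns complete commands."""

    def __init__(self):
        self._buf = bytearray()

    def feed(self, data):
        self._buf.extend(data)
        commands = []
        while self._buf:
            try:
                value, consumed = parse_one(self._buf)
            except Incomplete:
                break  # keep the partial message for the next call
            del self._buf[:consumed]
            commands.append(value)
        return commands
```

With this in place, data_received would call feed(data) and dispatch each returned command in order. For the rest of this post, though, we'll stick with the simplifying assumption.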

class RedisServerProtocol(asyncio.Protocol):


    def __init__(self, redis):
        self._redis = redis
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        parsed = parse_wire_protocol(data)
        # parsed is an array of [command, *args]
        command = parsed[0].decode().lower()
        try:
            method = getattr(self._redis, command)
        except AttributeError:
            self.transport.write(
                b"-ERR unknown command " + parsed[0] + b"\r\n")
            return
        result = method(*parsed[1:])
        serialized = serialize_to_wire(result)
        self.transport.write(serialized)


class WireRedisConverter(object):
    def __init__(self, redis):
        self._redis = redis

    def lrange(self, name, start, end):
        return self._redis.lrange(name, int(start), int(end))

    def hmset(self, name, *args):
        converted = {}
        iter_args = iter(list(args))
        for key, val in zip(iter_args, iter_args):
            converted[key] = val
        return self._redis.hmset(name, converted)

    def __getattr__(self, name):
        return getattr(self._redis, name)

The most important part here is the data_received method. Note that the first thing we do is take the bytes data we're given and immediately parse that into a python list using our parse_wire_protocol. The next thing we do is try to look for a corresponding method in the WireRedisConverter class based on the command we've been given. The WireRedisConverter class takes the parsed python list we receive from clients and maps that into the appropriate calls into fakeredis. For example:

HMSET myhash field1 "Hello"                                <- redis-cli
[b'HMSET', b'myhash', b'field1', b'Hello']                 <- parsed
WireRedisConverter.hmset(b'myhash', b'field1', b'Hello')
FakeStrictRedis.hmset(b'myhash', {b'field1': b'Hello'})

I've only shown a portion of WireRedisConverter, but there's enough to give you the basic idea of how a python list is mapped to fakeredis calls.

And finally, we serialize the python response back to bytes using serialize_to_wire and write this value out to the transport we received from connection_made.
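The hmset method in WireRedisConverter relies on a common python idiom: passing the same iterator to zip twice makes zip pull alternating items, so each field is paired with the value that follows it. Here's that idea in isolation as a small sketch; pairs_to_dict is a hypothetical name, and the odd-length check is our own addition.

```python
def pairs_to_dict(*args):
    """Turn a flat (k1, v1, k2, v2, ...) sequence into {k1: v1, k2: v2}."""
    if len(args) % 2:
        # HMSET needs field/value pairs; a dangling field is an error.
        raise ValueError('expected an even number of arguments')
    it = iter(args)
    # Both zip slots draw from the same iterator, so the tuples are
    # (args[0], args[1]), (args[2], args[3]), ...
    return dict(zip(it, it))
```

Without the length check, zip silently drops a trailing unpaired argument, whereas real redis answers HMSET with an odd number of field/value arguments with a wrong-number-of-arguments error.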

Wiring Up the Protocol Class

We'll also need to make a change to our main function, mostly in how we wire up the RedisServerProtocol:

import asyncio
import functools

import fakeredis


def main(hostname='localhost', port=6379):
    loop = asyncio.get_event_loop()
    wrapped_redis = WireRedisConverter(fakeredis.FakeStrictRedis())

    bound_protocol = functools.partial(RedisServerProtocol,
                                       wrapped_redis)
    coro = loop.create_server(bound_protocol,
                              hostname, port)
    server = loop.run_until_complete(coro)
    print("Listening on port {}".format(port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("User requested shutdown.")
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
        print("Redis is now ready to exit.")
    return 0

The biggest difference here is that we're using functools.partial so that we can pass in our wrapped fakeredis instance to the RedisServerProtocol class whenever it's created. As we saw earlier, the protocol_factory is called with no args and is expected to return a protocol instance. While we could write a protocol factory class, we're using functools.partial because that's all we need for now.
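For comparison, here's roughly what a protocol factory class could look like. It's a hypothetical sketch: ProtocolFactory is our own name, and StubProtocol stands in for RedisServerProtocol so the snippet runs on its own. All asyncio requires is a zero-argument callable that returns a protocol instance, and an object with __call__ satisfies that just as well as functools.partial does.

```python
import asyncio


class StubProtocol(asyncio.Protocol):
    # Stand-in for RedisServerProtocol: just remembers what it was given.
    def __init__(self, redis):
        self._redis = redis


class ProtocolFactory:
    """Callable object; create_server invokes it once per new connection."""

    def __init__(self, protocol_class, redis):
        self._protocol_class = protocol_class
        self._redis = redis
        self.created = 0  # unlike partial, a class can carry its own state

    def __call__(self):
        self.created += 1
        # Every connection gets a new protocol sharing one redis object.
        return self._protocol_class(self._redis)
```

In main this would read loop.create_server(ProtocolFactory(RedisServerProtocol, wrapped_redis), hostname, port). The extra machinery only pays off if the factory itself needs state, such as the counter here; otherwise functools.partial is the simpler choice.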

Testing it Out

And finally, we should have something that vaguely resembles redis. Let's try it out:

$ ./redis-asyncio &
[1] 55470

$ redis-cli
127.0.0.1:6379> set foo bar
OK
127.0.0.1:6379> get foo
"bar"
127.0.0.1:6379> set foo baz
OK
127.0.0.1:6379> get foo
"baz"

127.0.0.1:6379> lpush abc 1
(integer) 1
127.0.0.1:6379> lpush abc 2
(integer) 2
127.0.0.1:6379> lpush abc 3
(integer) 3
127.0.0.1:6379> lrange abc 0 -1
1) "3"
2) "2"
3) "1"

127.0.0.1:6379> hmset myhash field1 "hello" field2 "world"
OK
127.0.0.1:6379> hget myhash field1
"hello"
127.0.0.1:6379> hget myhash field2
"world"

127.0.0.1:6379> sadd myset "Hello"
(integer) 1
127.0.0.1:6379> sadd myset "World"
(integer) 1
127.0.0.1:6379> sadd myset "World"
(integer) 0
127.0.0.1:6379> smembers myset
1) "Hello"
2) "World"

Let's even try talking to ./redis-asyncio using the redis-py module:

>>> import redis
>>> r = redis.Redis()
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
b'bar'

>>> r.lpush('mylist', 1)
1
>>> r.lpush('mylist', 2)
2
>>> r.lrange('mylist', 0, -1)
[b'2', b'1']

>>> r.sadd('myset', 'hello')
1
>>> r.sadd('myset', 'world')
1
>>> r.sadd('myset', 'world')
0
>>> r.smembers('myset')
{b'world', b'hello'}

>>> r.hmset('myhash', {'a': 'b', 'c': 'd'})
True
>>> r.hget('myhash', 'a')
b'b'
>>> r.hget('myhash', 'c')
b'd'

Wrapping Up

In this post, we looked at getting a basic redis implementation up and running using asyncio and fakeredis. We were able to run basic commands such as get, set, lpush, lrange, sadd, smembers, hmset, and hget.

In the next post, we'll look at implementing blocking operations such as BLPOP.

And, because you're probably just as curious as I was, here's a benchmark comparison between what we've written and the real redis server. These benchmarks were run on the same machine so it's the relative difference that's interesting to me. I wouldn't read too much into it though.

redis-benchmark -t set -n 200000

redis-server:
====== SET ======
  200000 requests completed
    in 1.52 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

99.97% <= 1 milliseconds
100.00% <= 1 milliseconds
131926.12 requests per second

redis-asyncio:
====== SET ======
  200000 requests completed
    in 5.17 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

0.25% <= 1 milliseconds
99.48% <= 2 milliseconds
99.96% <= 3 milliseconds
99.98% <= 4 milliseconds
99.99% <= 5 milliseconds
99.99% <= 6 milliseconds
99.99% <= 7 milliseconds
99.99% <= 8 milliseconds
99.99% <= 9 milliseconds
99.99% <= 10 milliseconds
99.99% <= 11 milliseconds
99.99% <= 12 milliseconds
99.99% <= 13 milliseconds
99.99% <= 14 milliseconds
99.99% <= 15 milliseconds
99.99% <= 16 milliseconds
99.99% <= 18 milliseconds
100.00% <= 19 milliseconds
100.00% <= 20 milliseconds
100.00% <= 21 milliseconds
100.00% <= 22 milliseconds
100.00% <= 24 milliseconds
100.00% <= 25 milliseconds
100.00% <= 26 milliseconds
100.00% <= 27 milliseconds
100.00% <= 29 milliseconds
100.00% <= 30 milliseconds
100.00% <= 32 milliseconds
38654.81 requests per second
      

redis-benchmark -t get -n 200000

redis-server:
====== GET ======
  200000 requests completed
    in 1.53 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

100.00% <= 0 milliseconds
130975.77 requests per second

redis-asyncio:
====== GET ======
  200000 requests completed
    in 6.42 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

0.18% <= 1 milliseconds
96.55% <= 2 milliseconds
99.83% <= 3 milliseconds
99.98% <= 4 milliseconds
99.99% <= 5 milliseconds
99.99% <= 6 milliseconds
99.99% <= 7 milliseconds
99.99% <= 8 milliseconds
99.99% <= 9 milliseconds
99.99% <= 10 milliseconds
99.99% <= 11 milliseconds
99.99% <= 12 milliseconds
99.99% <= 13 milliseconds
99.99% <= 14 milliseconds
99.99% <= 15 milliseconds
99.99% <= 17 milliseconds
99.99% <= 18 milliseconds
99.99% <= 19 milliseconds
99.99% <= 21 milliseconds
99.99% <= 22 milliseconds
99.99% <= 23 milliseconds
100.00% <= 24 milliseconds
100.00% <= 26 milliseconds
100.00% <= 27 milliseconds
100.00% <= 29 milliseconds
100.00% <= 31 milliseconds
100.00% <= 32 milliseconds
100.00% <= 34 milliseconds
100.00% <= 35 milliseconds
100.00% <= 37 milliseconds
100.00% <= 39 milliseconds
100.00% <= 41 milliseconds
31157.50 requests per second
      
