I've been writing redis in python using asyncio. I started this project because I wanted to learn more about asyncio and thought that porting an existing project to asyncio would give me an excellent opportunity to learn about this new library.

Part 1 of this series covered how to implement a basic request/response for redis using asyncio. It covered how to use protocols, how they hooked into asyncio, and how you could parse and serialize requests and responses. Since part 1 was first published over a year ago (I know...), a few things have happened:

  1. Python 3.5 added the async and await keywords, which changed the recommended way of working with coroutines.
  2. I had the opportunity to speak about this topic at EuroPython 2016.

The slides for my talk are available on speakerdeck.

You can also check out the talk here.

Compared to Part 1 of this series, the EuroPython talk covered two additional topics:

  • PUBLISH/SUBSCRIBE
  • BLPOP/BRPOP (blocking queues)

Covering BLPOP/BRPOP also required a quick detour into async and await.

In this next series of posts, I want to discuss these topics in more detail and cover some additional topics that I omitted from my talk due to time constraints.

For the remainder of this post, we're going to look at how to implement the PUBLISH and SUBSCRIBE commands in redis using asyncio.

Publish/Subscribe

I'm assuming you're familiar with redis, but here's a quick reminder of the PUBSUB feature in redis, and what we're shooting for in this post:

And here's a video of this in action:

In the video, you see two clients SUBSCRIBE to a channel. Those clients will then block until another client comes along and issues a PUBLISH command to that channel. You can see that when the bottom client issues a PUBLISH command, the two top clients subscribed to the channel receive the published message. The redis docs on pubsub have a more detailed overview of this feature.

Let's look at how to do this in python using asyncio.

Sharing State

Whenever we receive a PUBLISH command from a client, we need to send the message being published to every previously subscribed client. Transports need to be able to talk to other transports. More generally, we need a way to share state across transports.

As a newcomer to the async world, this was one of the hardest things for me to figure out. How are you supposed to share state across connections?

What I found helpful was to first write code that assumed shared state and then figure out how to plumb it all together later. For this PUBSUB feature, let's create a PubSub class that allows transports to subscribe and publish to channels:

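import serializer  # the wire-protocol serializer from part 1


class PubSub:
    def __init__(self):
        # Maps channel name -> list of subscribed transports.
        self._channels = {}

    def subscribe(self, channel, transport):
        self._channels.setdefault(channel, []).append(transport)
        # Simplification: real redis reports the number of channels this
        # client is now subscribed to; we always report 1.
        return ['subscribe', channel, 1]

    def publish(self, channel, message):
        transports = self._channels.get(channel, [])
        # Serialize once, then write the same bytes to every subscriber.
        message = serializer.serialize_to_wire(
            ['message', channel, message])
        for transport in transports:
            transport.write(message)
        # PUBLISH replies with the number of clients that got the message.
        return len(transports)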

In the class above, we maintain a mapping of channel names (byte strings, as parsed off the wire) to lists of transports. Every time a client subscribes to a channel, we add its transport to the list associated with that channel. Whenever a client publishes a message, we serialize it once and write it to every transport subscribed to that channel, returning the number of subscribers that received it.
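To see the fan-out in isolation, here's a minimal sketch that drives PubSub without any networking. It's self-contained, so it inlines a trimmed copy of serialize_to_wire from part 1 (only the types used here) in place of the serializer module, and RecordingTransport is a hypothetical stub standing in for a real asyncio transport; it just records every write.

```python
def serialize_to_wire(value):
    # Trimmed copy of the part 1 serializer: only the types PubSub uses.
    if isinstance(value, str):
        return ('+%s' % value).encode() + b'\r\n'
    elif isinstance(value, int):
        return (':%s' % value).encode() + b'\r\n'
    elif isinstance(value, bytes):
        return b'$' + str(len(value)).encode() + b'\r\n' + value + b'\r\n'
    elif isinstance(value, list):
        return (b'*' + str(len(value)).encode() + b'\r\n' +
                b''.join(serialize_to_wire(item) for item in value))
    raise TypeError('cannot serialize %r' % (value,))


class PubSub:
    def __init__(self):
        self._channels = {}

    def subscribe(self, channel, transport):
        self._channels.setdefault(channel, []).append(transport)
        return ['subscribe', channel, 1]

    def publish(self, channel, message):
        transports = self._channels.get(channel, [])
        # Serialize once, then write the same bytes to every subscriber.
        message = serialize_to_wire(['message', channel, message])
        for transport in transports:
            transport.write(message)
        return len(transports)


class RecordingTransport:
    """Hypothetical stand-in for an asyncio transport: records writes."""

    def __init__(self):
        self.written = []

    def write(self, data):
        self.written.append(data)


pubsub = PubSub()
alice, bob = RecordingTransport(), RecordingTransport()
pubsub.subscribe(b'news', alice)
pubsub.subscribe(b'news', bob)
receivers = pubsub.publish(b'news', b'hello')   # both subscribers get it
nobody = pubsub.publish(b'sports', b'goal')     # no subscribers on this channel
```

Both stubs receive identical bytes, publish returns 2, and publishing to a channel nobody subscribed to returns 0. One difference from real redis: because 'message' is a Python str here, it goes out as a simple string (+message), whereas redis sends the parts of a pubsub message as bulk strings.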

We'd use this class from our RedisServerProtocol class, assuming for now that a PubSub instance is available as the self._pubsub instance variable:

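import asyncio

import parser      # the wire-protocol parser from part 1
import serializer  # the wire-protocol serializer from part 1


class RedisServerProtocol(asyncio.Protocol):
    def __init__(self, pubsub):
        # The single PubSub instance shared by every connection.
        self._pubsub = pubsub
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        parsed = parser.parse_wire_protocol(data)
        # [COMMAND, arg1, arg2]
        command = parsed[0].lower()
        if command == b'subscribe':
            response = self._pubsub.subscribe(parsed[1], self.transport)
        elif command == b'publish':
            response = self._pubsub.publish(parsed[1], parsed[2])
        else:
            self.transport.write(b"-ERR unknown command\r\n")
            return
        self.transport.write(serializer.serialize_to_wire(response))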

For this code to work, there can only be a single instance of the PubSub class that's shared across all the incoming connections. We need a way to make sure that whenever we create a protocol instance, we can also inject a shared reference to a PubSub instance.

Let's refresh our memories first. In part 1 of this series, we talked about protocols and transports. One of the main takeaways from that post is that every time a client connects to our server, there is a protocol instance and a transport instance associated with that connection. It looks like this:

A protocol factory is used to create a protocol instance which is associated with a single connection. This factory is just a callable that returns an instance of a protocol. Here's how a protocol factory is used in the asyncio code base, asyncio/selector_events.py:

# Simplified excerpt from asyncio/selector_events.py.
def _accept_connection2(
        self, protocol_factory, conn, extra, server=None):
    protocol = None
    transport = None
    try:
        protocol = protocol_factory()   # RedisServerProtocol
        waiter = futures.Future(loop=self)
        transport = _SelectorSocketTransport(self, conn, protocol,
                                             waiter, extra, server)
        # ...
    except Exception as exc:
        # ...
        pass

Because asyncio calls the protocol factory with no arguments, we need some other way to bind our PubSub instance to it. We could use functools.partial (which is what part 1 used), but I've found that having a distinct class for this makes things easier:

class ProtocolFactory:
    def __init__(self, protocol_cls, *args, **kwargs):
        self._protocol_cls = protocol_cls
        self._args = args
        self._kwargs = kwargs

    def __call__(self):
        # No arg callable is used to instantiate
        # protocols in asyncio.
        return self._protocol_cls(*self._args, **self._kwargs)
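Since asyncio only ever calls the factory with no arguments, the class above and functools.partial end up doing the same job. The sketch below checks that each call builds a fresh protocol while every instance shares one PubSub; to keep it runnable on its own, it repeats ProtocolFactory and uses hypothetical minimal stand-ins for PubSub and RedisServerProtocol.

```python
import functools


class PubSub:
    """Hypothetical minimal stand-in for the shared PubSub object."""


class RedisServerProtocol:
    """Hypothetical minimal stand-in: only stores the injected pubsub."""

    def __init__(self, pubsub):
        self._pubsub = pubsub


class ProtocolFactory:
    def __init__(self, protocol_cls, *args, **kwargs):
        self._protocol_cls = protocol_cls
        self._args = args
        self._kwargs = kwargs

    def __call__(self):
        return self._protocol_cls(*self._args, **self._kwargs)


shared = PubSub()
class_factory = ProtocolFactory(RedisServerProtocol, shared)
partial_factory = functools.partial(RedisServerProtocol, shared)

# asyncio calls the factory once per accepted connection, with no arguments.
first, second = class_factory(), class_factory()
third = partial_factory()
```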

Now instead of passing the RedisServerProtocol to the loop.create_server call, we can pass an instance of the protocol factory class we just created. Here's how everything looks once it's wired together:

import asyncio


def main(hostname='localhost', port=6379):
    loop = asyncio.get_event_loop()
    # Every protocol instance this factory creates gets the same PubSub.
    factory = ProtocolFactory(
        RedisServerProtocol, PubSub()
    )
    coro = loop.create_server(factory, hostname, port)
    server = loop.run_until_complete(coro)
    print("Listening on port {}".format(port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("Ctrl-C received, shutting down.")
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
        print("Server shutdown.")
    return 0

And that's all you need to get a basic PUBSUB implementation up and running using asyncio.

Wrapping Up

To summarize what we've done:

  • Created a new PubSub class that lets you subscribe a transport to a channel name and publish a message to a channel.
  • Updated the RedisServerProtocol class to accept a reference to this object in its __init__.
  • Updated RedisServerProtocol.data_received to use this _pubsub instance whenever we receive a PUBLISH or SUBSCRIBE command.
  • Created a protocol factory that passes the same shared PubSub object to every protocol instance it creates.

In the next post, we'll look at how you can implement BLPOP/BRPOP with asyncio.

One last thing. I'm in the process of getting this code on github. I'll update this post with a link once the repo is available, or you can follow me on twitter where I'll also post a link.


Python 3.4 featured a brand new library that's been getting a lot of attention: asyncio. For numerous reasons, including the fact that the originator of the PEP is Guido himself, the asyncio library is growing in popularity within the python community.

So, I'm thinking that it might be fun to try to use this new asyncio library to write redis in pure python.

The Pitch

I maintain a library, fakeredis, which is a testing library that emulates redis via the redis-py client API. It doesn't have any of the runtime guarantees that redis has (yet), but for the most part, it has the same functional behavior as redis. All of its state is kept in memory, and it does nothing for persisting state to disk. After all, it's meant as a testing library, to avoid having to spin up a real redis server during your python unit tests. Here's fakeredis in a nutshell:

>>> import fakeredis
>>> r = fakeredis.FakeStrictRedis()
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
'bar'
>>> r.lpush('bar', 1)
1
>>> r.lpush('bar', 2)
2
>>> r.lrange('bar', 0, -1)
[2, 1]

Same semantics as the redis-py client, with all state stored in the memory of the process.

So the idea is to use asyncio to build a server that accepts client connections and parses the redis wire protocol. The server then figures out the corresponding calls to make into fakeredis, which provides all the functional redis semantics, and turns fakeredis's return values into the appropriate wire response.

If all goes well I should have something that any redis client can talk to, without knowing it's not actually talking to the real redis server. We will have created a slower, less memory-efficient implementation of the redis server without "required" features like "persistence" or "replication". It's redis, written in python, using asyncio. Sounds like fun.

If nothing else, we'll learn a little more about asyncio in the process.

Setting Scope

Now first off, I plan for this to be a multipart series.

The scope for this post, part 1, is to get to the point where we can make redis calls for all its basic functionality, which includes the API calls for manipulating data for all of redis's supported types. Perhaps what's more interesting is what I'm leaving out in this post.

What I won't look at in this post is:

  • saving to disk
  • blocking operations, such as BLPOP
  • performance
  • handling slow clients
  • expirations
  • any kind of replication
  • testing

These items will be the subject of future posts. This is a long-winded way of saying that we're going to be taking shortcuts. It'll be ok.

Assumptions

To get the most out of this post, I'm assuming that:

  • You're familiar with redis from an end-user perspective. You know what redis is and you're familiar with the basic commands.
  • You're new to asyncio, but you're not necessarily new to event driven programming.
  • You're using python 3.4 or greater.

Get the Skeleton Up and Running

The very first thing I want to do is get something up and running. It doesn't have to do much, but I want to be able to at least have the server handle a request and return a response, even if it's hardcoded. I'm going to pick the GET command because it's the simplest operation that provides useful functionality. Once we get this running, we'll pick it apart and figure out how it actually works.

So first things first, let's hop over to the asyncio reference docs.

End to End Skeleton

Asyncio appears to have a huge amount of documentation, but most of it is stuff I don't care about right now. The most relevant thing I found is this TCP echo server protocol example, which shows a basic echo server with asyncio. We should be able to start with the echo server and adapt it to what we want, at least initially. Here's what I came up with after adapting the echo server example to a hard-coded redis GET command.

import asyncio


class RedisServerProtocol(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        message = data.decode()
        if 'GET' in message:
            self.transport.write(b"$3\r\n")
            self.transport.write(b"BAZ\r\n")
        else:
            self.transport.write(b"-ERR unknown command\r\n")


def main(hostname='localhost', port=6379):
    loop = asyncio.get_event_loop()
    coro = loop.create_server(RedisServerProtocol,
                              hostname, port)
    server = loop.run_until_complete(coro)
    print("Listening on port {}".format(port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("User requested shutdown.")
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
        print("Redis is now ready to exit.")
    return 0


if __name__ == '__main__':
    main()

Save the above code to a file named redis-asyncio and run it. We'll use redis-cli to verify it has the behavior we want:

$ ./redis-asyncio &
[1] 96221
Listening on port 6379

$ redis-cli
127.0.0.1:6379> GET foo
"BAZ"
127.0.0.1:6379> GET bar
"BAZ"
127.0.0.1:6379> GET anything
"BAZ"
127.0.0.1:6379> FOOBAR asdf
(error) ERR unknown command

It works!

But How Does it Work?

There's a lot we haven't explained yet.

While I'm going to skip over get_event_loop and run_until_complete for now, the create_server call is interesting. How exactly does the server we create integrate with the RedisServerProtocol we made? For example, how do we go from create_server to calling RedisServerProtocol.connection_made?

What helped me the most was digging into the source code for asyncio, so let's do that. I've annotated and simplified the code to give you a high-level view of what's going on. We'll start with create_server and keep going through the various methods until we see our protocol's connection_made method being called.

# These are all methods within an EventLoop class.

@coroutine
def create_server(self, protocol_factory, host=None, port=None,
                  *,
                  family=socket.AF_UNSPEC,
                  flags=socket.AI_PASSIVE,
                  sock=None,
                  backlog=100,
                  ssl=None,
                  reuse_address=None):
    # In this scenario the ``protocol_factory`` maps
    # to the ``RedisServerProtocol`` class object.

    # Create listening socket(s).
    sock = lots_of_code()

    server = Server(self, sock)
    # Once we create a server, we call _start_serving.
    # Note how we're passing along the protocol_factory
    # argument (our ``RedisServerProtocol`` class).
    self._start_serving(protocol_factory, sock, ssl, server)
    return server

def _start_serving(self, protocol_factory, sock,
                   sslcontext=None, server=None):
    # We're registering the _accept_connection method to be called
    # when a new connection is made.  Again notice how we're
    # still passing along our protocol_factory (``RedisServerProtocol``
    # class) object.
    self.add_reader(sock.fileno(), self._accept_connection,
                    protocol_factory, sock, sslcontext, server)

def _accept_connection(self, protocol_factory, sock,
                       sslcontext=None, server=None):
    # Accept the pending client connection on the listening socket.
    conn, addr = sock.accept()
    # Finally!  Here we can see that we call
    # the protocol_factory class to actually get
    # an instance of ``RedisServerProtocol``.
    # We've gone from a class to an instance.  So
    # what about connection_made?  When does this get
    # called?  Down the stack to _make_socket_transport.
    self._make_socket_transport(
        conn, protocol_factory(), extra={'peername': addr},
        server=server)

def _make_socket_transport(self, sock, protocol, waiter=None, *,
                           extra=None, server=None):
    # At least we can see we've gone from protocol_factory to
    # just protocol, so now "protocol" in this scenario is an
    # instance of ``RedisServerProtocol``.
    return _SelectorSocketTransport(self, sock, protocol, waiter,
                                    extra, server)

class _SelectorSocketTransport(_SelectorTransport):
    def __init__(self, loop, sock, protocol, waiter=None,
                 extra=None, server=None):
        super().__init__(loop, sock, protocol, extra, server)
        self._eof = False
        self._paused = False

        self._loop.add_reader(self._sock_fd, self._read_ready)
        # And finally, we see that we ask the event loop to call
        # the connection_made method of our protocol instance, and we're
        # passing "self" (the transport object) as an argument to
        # connection_made.
        self._loop.call_soon(self._protocol.connection_made, self)

Recap

So far, we've learned:

  • It looks like the interesting stuff we'll be writing is in the Protocol. To write our own redis server, we're going to flesh out a proper RedisServerProtocol class that understands the redis wire protocol.
  • We get one protocol instance per client connection. State stored on the protocol is scoped to the lifetime of that connection.
  • To wire things up, hand the protocol class to create_server, which is called on an event loop instance. As we saw in _accept_connection() above, the protocol_factory argument is called with no args to create a protocol instance. While a class object works fine for now, we're going to need a closure or a factory class to pass arguments to the protocol when it's created.
  • Protocols let you define methods that the event loop invokes: asyncio calls connection_made() when a client connects and data_received() when data arrives. Looking at the Protocol classes, there appear to be a few more methods you can implement.
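For the closure option mentioned in the recap, a minimal sketch might look like this. make_protocol_factory and the stand-in RedisServerProtocol are hypothetical names for illustration; the point is that the inner function captures the shared object and takes no arguments, which is what create_server needs.

```python
class RedisServerProtocol:
    """Hypothetical stand-in for the protocol class built later in this post."""

    def __init__(self, redis):
        self._redis = redis


def make_protocol_factory(protocol_cls, shared):
    # The inner function closes over protocol_cls and shared, so asyncio
    # can call it with no arguments for each new connection.
    def protocol_factory():
        return protocol_cls(shared)
    return protocol_factory


shared = object()  # stands in for the wrapped fakeredis instance
factory = make_protocol_factory(RedisServerProtocol, shared)
p1, p2 = factory(), factory()
```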

Now that we understand the basics, we can start looking at the redis wire protocol.

Parsing the Wire Protocol

The first thing we need in order to properly handle requests is a protocol parser: the code that takes the redis request off the TCP socket and parses it into something meaningful. The code for this isn't that interesting; after reading the docs for the redis wire protocol, it's straightforward to implement.

Now, none of this is optimized yet, but here's a basic implementation of parsing the redis wire protocol. It accepts a byte string, and returns python objects.

import io


def parse_wire_protocol(message):
    return _parse_wire_protocol(io.BytesIO(message))


def _parse_wire_protocol(msg_buffer):
    current_line = msg_buffer.readline()
    msg_type, remaining = chr(current_line[0]), current_line[1:]
    if msg_type == '+':
        return remaining.rstrip(b'\r\n').decode()
    elif msg_type == ':':
        return int(remaining)
    elif msg_type == '$':
        msg_length = int(remaining)
        if msg_length == -1:
            return None
        result = msg_buffer.read(msg_length)
        # There's a '\r\n' that comes after a bulk string
        # so we .readline() to move past that crlf.
        msg_buffer.readline()
        return result
    elif msg_type == '*':
        array_length = int(remaining)
        return [_parse_wire_protocol(msg_buffer) for _ in range(array_length)]

We're also going to need the inverse of this, something that takes a response from fakeredis and converts it back into bytes that can be sent across the wire. Again, nothing too interesting about this code, but here's what I came up with:

def serialize_to_wire(value):
    if isinstance(value, str):
        return ('+%s' % value).encode() + b'\r\n'
    # bool must be checked before int because bool is a subclass of int.
    # (False falls through to the int branch and is sent as :0.)
    elif isinstance(value, bool) and value:
        return b"+OK\r\n"
    elif isinstance(value, int):
        return (':%s' % value).encode() + b'\r\n'
    elif isinstance(value, bytes):
        return (b'$' + str(len(value)).encode() +
                b'\r\n' + value + b'\r\n')
    elif value is None:
        return b'$-1\r\n'
    elif isinstance(value, list):
        base = b'*' + str(len(value)).encode() + b'\r\n'
        for item in value:
            base += serialize_to_wire(item)
        return base
    # Fail loudly instead of silently returning None for other types.
    raise TypeError('Cannot serialize %r' % (value,))

Let's try this out:

>>> set_request = b'*3\r\n$3\r\nset\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'
>>> parse_wire_protocol(set_request)
[b'set', b'foo', b'bar']

>>> serialize_to_wire([b'5', b'4', b'3', b'2', b'1'])
b'*5\r\n$1\r\n5\r\n$1\r\n4\r\n$1\r\n3\r\n$1\r\n2\r\n$1\r\n1\r\n'

After calling parse_wire_protocol, we can see that we get a list of [command_name, arg1, arg2, ...].

Implementing the Protocol Class

We should have everything we need to make a more realistic RedisServerProtocol class now. We're making the assumption for now that the entire command is provided when data_received is called.

class RedisServerProtocol(asyncio.Protocol):


    def __init__(self, redis):
        self._redis = redis
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        parsed = parse_wire_protocol(data)
        # parsed is an array of [command, *args]
        command = parsed[0].decode().lower()
        try:
            method = getattr(self._redis, command)
        except AttributeError:
            self.transport.write(
                b"-ERR unknown command " + parsed[0] + b"\r\n")
            return
        result = method(*parsed[1:])
        serialized = serialize_to_wire(result)
        self.transport.write(serialized)


class WireRedisConverter(object):
    def __init__(self, redis):
        self._redis = redis

    def lrange(self, name, start, end):
        return self._redis.lrange(name, int(start), int(end))

    def hmset(self, name, *args):
        converted = {}
        iter_args = iter(args)
        for key, val in zip(iter_args, iter_args):
            converted[key] = val
        return self._redis.hmset(name, converted)

    def __getattr__(self, name):
        return getattr(self._redis, name)

The most important part here is the data_received method. The first thing we do is take the bytes we're given and immediately parse them into a python list using parse_wire_protocol. Next, we look for a corresponding method on the WireRedisConverter instance based on the command we've been given. WireRedisConverter takes the parsed python list we receive from clients and maps it into the appropriate calls into fakeredis. For example:

HMSET myhash field1 "Hello"                           <- redis-cli
['hmset', 'myhash', 'field1', 'Hello']                <- parsed
WireRedisConverter.hmset('myhash', 'field1', 'Hello')
FakeRedis.hmset('myhash', {'field1': 'Hello'})

I've only shown a portion of WireRedisConverter, but there's enough to give you the basic idea of how a parsed python list is mapped to fakeredis calls.

And finally, we serialize the python response back to bytes using serialize_to_wire and write this value out to the transport we received from connection_made.
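The hmset method above leans on a compact idiom: passing the same iterator to zip() twice, so each tuple takes one item as the field and the next as its value. Here it is on its own (pairs_to_dict is a hypothetical helper, not part of WireRedisConverter):

```python
def pairs_to_dict(flat_args):
    # Both arguments to zip() are the *same* iterator, so every tuple
    # zip() builds consumes one item for the key and the next for the value.
    iter_args = iter(flat_args)
    return dict(zip(iter_args, iter_args))


converted = pairs_to_dict([b'field1', b'hello', b'field2', b'world'])
```

One caveat of this sketch: with an odd number of arguments, zip() stops early and the trailing field is silently dropped, whereas real redis rejects an HMSET with the wrong number of arguments.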

Wiring Up the Protocol Class

We'll also need to make a change to our main function, mostly in how we wire up the RedisServerProtocol:

import asyncio
import functools

import fakeredis


def main(hostname='localhost', port=6379):
    loop = asyncio.get_event_loop()
    wrapped_redis = WireRedisConverter(fakeredis.FakeStrictRedis())

    bound_protocol = functools.partial(RedisServerProtocol,
                                       wrapped_redis)
    coro = loop.create_server(bound_protocol,
                              hostname, port)
    server = loop.run_until_complete(coro)
    print("Listening on port {}".format(port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("User requested shutdown.")
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.close()
        print("Redis is now ready to exit.")
    return 0

The biggest difference here is that we're using functools.partial so that we can pass in our wrapped fakeredis instance to the RedisServerProtocol class whenever it's created. As we saw earlier, the protocol_factory is called with no args and is expected to return a protocol instance. While we could write a protocol factory class, we're using functools.partial because that's all we need for now.
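The protocol in this post assumes every data_received call carries exactly one complete command, but TCP delivers a byte stream: a command can arrive split across several calls, or several commands can arrive in one. As a hedged sketch of one way to relax that assumption (not the approach used in this post; try_parse and CommandBuffer are hypothetical names), the parser can report "incomplete" instead of failing, and the protocol can keep unconsumed bytes in a buffer:

```python
def try_parse(buf, pos=0):
    """Parse one RESP value starting at buf[pos].

    Returns (value, next_pos), or None if buf doesn't yet hold the
    complete value.
    """
    line_end = buf.find(b'\r\n', pos)
    if line_end == -1:
        return None
    msg_type, header = buf[pos:pos + 1], buf[pos + 1:line_end]
    pos = line_end + 2
    if msg_type == b'+':
        return header.decode(), pos
    if msg_type == b':':
        return int(header), pos
    if msg_type == b'$':
        length = int(header)
        if length == -1:
            return None, pos              # null bulk string: a complete value
        if len(buf) < pos + length + 2:
            return None                   # body or trailing CRLF not here yet
        return buf[pos:pos + length], pos + length + 2
    if msg_type == b'*':
        items = []
        for _ in range(int(header)):
            parsed = try_parse(buf, pos)
            if parsed is None:
                return None
            item, pos = parsed
            items.append(item)
        return items, pos
    raise ValueError('unknown RESP type byte: %r' % msg_type)


class CommandBuffer:
    """Accumulates bytes from data_received and returns complete commands."""

    def __init__(self):
        self._buffer = b''

    def feed(self, data):
        self._buffer += data
        commands = []
        while self._buffer:
            parsed = try_parse(self._buffer)
            if parsed is None:
                break                     # wait for the next data_received
            command, consumed = parsed
            commands.append(command)
            self._buffer = self._buffer[consumed:]
        return commands
```

In data_received you'd call feed(data) and dispatch each returned command; anything left over stays buffered until the next call. This sketch re-parses an incomplete command from the start on every call and copies the buffer as it goes, which is fine for illustration but could get expensive for large commands.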

Testing it Out

And finally, we should have something that vaguely resembles redis. Let's try it out:

$ ./redis-asyncio &
[1] 55470

$ redis-cli
127.0.0.1:6379> set foo bar
OK
127.0.0.1:6379> get foo
"bar"
127.0.0.1:6379> set foo baz
OK
127.0.0.1:6379> get foo
"baz"

127.0.0.1:6379> lpush abc 1
(integer) 1
127.0.0.1:6379> lpush abc 2
(integer) 2
127.0.0.1:6379> lpush abc 3
(integer) 3
127.0.0.1:6379> lrange abc 0 -1
1) "3"
2) "2"
3) "1"

127.0.0.1:6379> hmset myhash field1 "hello" field2 "world"
OK
127.0.0.1:6379> hget myhash field1
"hello"
127.0.0.1:6379> hget myhash field2
"world"

127.0.0.1:6379> sadd myset "Hello"
(integer) 1
127.0.0.1:6379> sadd myset "World"
(integer) 1
127.0.0.1:6379> sadd myset "World"
(integer) 0
127.0.0.1:6379> smembers myset
1) "Hello"
2) "World"

Let's even try talking to ./redis-asyncio using the redis-py module:

>>> import redis
>>> r = redis.Redis()
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
b'bar'

>>> r.lpush('mylist', 1)
1
>>> r.lpush('mylist', 2)
2
>>> r.lrange('mylist', 0, -1)
[b'2', b'1']

>>> r.sadd('myset', 'hello')
1
>>> r.sadd('myset', 'world')
1
>>> r.sadd('myset', 'world')
0
>>> r.smembers('myset')
{b'world', b'hello'}

>>> r.hmset('myhash', {'a': 'b', 'c': 'd'})
True
>>> r.hget('myhash', 'a')
b'b'
>>> r.hget('myhash', 'c')
b'd'

Wrapping Up

In this post, we looked at getting a basic redis implementation up and running using asyncio and fakeredis. We were able to run basic commands such as get, set, lpush, lrange, sadd, smembers, hmset, hget, etc.

In the next post, we'll look at implementing blocking operations such as BLPOP.

And, because you're probably just as curious as I was, here's a benchmark comparison between what we've written and the real redis server. These benchmarks were run on the same machine so it's the relative difference that's interesting to me. I wouldn't read too much into it though.

redis-benchmark -t set -n 200000

redis-server:
====== SET ======
  200000 requests completed
    in 1.52 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

99.97% <= 1 milliseconds
100.00% <= 1 milliseconds
131926.12 requests per second

redis-asyncio:
====== SET ======
  200000 requests completed
    in 5.17 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

0.25% <= 1 milliseconds
99.48% <= 2 milliseconds
99.96% <= 3 milliseconds
99.98% <= 4 milliseconds
99.99% <= 5 milliseconds
99.99% <= 6 milliseconds
99.99% <= 7 milliseconds
99.99% <= 8 milliseconds
99.99% <= 9 milliseconds
99.99% <= 10 milliseconds
99.99% <= 11 milliseconds
99.99% <= 12 milliseconds
99.99% <= 13 milliseconds
99.99% <= 14 milliseconds
99.99% <= 15 milliseconds
99.99% <= 16 milliseconds
99.99% <= 18 milliseconds
100.00% <= 19 milliseconds
100.00% <= 20 milliseconds
100.00% <= 21 milliseconds
100.00% <= 22 milliseconds
100.00% <= 24 milliseconds
100.00% <= 25 milliseconds
100.00% <= 26 milliseconds
100.00% <= 27 milliseconds
100.00% <= 29 milliseconds
100.00% <= 30 milliseconds
100.00% <= 32 milliseconds
38654.81 requests per second
      

redis-benchmark -t get -n 200000

redis-server:
====== GET ======
  200000 requests completed
    in 1.53 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

100.00% <= 0 milliseconds
130975.77 requests per second

redis-asyncio:
====== GET ======
  200000 requests completed
    in 6.42 seconds
  50 parallel clients
  3 bytes payload
  keep alive: 1

0.18% <= 1 milliseconds
96.55% <= 2 milliseconds
99.83% <= 3 milliseconds
99.98% <= 4 milliseconds
99.99% <= 5 milliseconds
99.99% <= 6 milliseconds
99.99% <= 7 milliseconds
99.99% <= 8 milliseconds
99.99% <= 9 milliseconds
99.99% <= 10 milliseconds
99.99% <= 11 milliseconds
99.99% <= 12 milliseconds
99.99% <= 13 milliseconds
99.99% <= 14 milliseconds
99.99% <= 15 milliseconds
99.99% <= 17 milliseconds
99.99% <= 18 milliseconds
99.99% <= 19 milliseconds
99.99% <= 21 milliseconds
99.99% <= 22 milliseconds
99.99% <= 23 milliseconds
100.00% <= 24 milliseconds
100.00% <= 26 milliseconds
100.00% <= 27 milliseconds
100.00% <= 29 milliseconds
100.00% <= 31 milliseconds
100.00% <= 32 milliseconds
100.00% <= 34 milliseconds
100.00% <= 35 milliseconds
100.00% <= 37 milliseconds
100.00% <= 39 milliseconds
100.00% <= 41 milliseconds
31157.50 requests per second
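To put the two runs side by side, here's the arithmetic on the requests-per-second figures copied from the output above:

```python
# Requests per second, copied from the redis-benchmark output above:
# (command, redis-server, redis-asyncio)
results = [
    ('SET', 131926.12, 38654.81),
    ('GET', 130975.77, 31157.50),
]

ratios = {}
for command, server_rps, asyncio_rps in results:
    ratios[command] = server_rps / asyncio_rps
    print('{}: redis-server handled {:.1f}x the requests per second'.format(
        command, ratios[command]))
```

By these numbers, redis-server handled roughly 3.4x as many SET requests and 4.2x as many GET requests per second as redis-asyncio on the same machine; as noted above, I wouldn't read too much into it.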
      

I'm going to show you how a micro-optimization can speed up your python code by a whopping 5%. 5%! It can also annoy anyone who has to maintain your code.

But really, this is about explaining code you might occasionally see in the standard library or in other people's code. Let's take an example from the standard library, specifically the collections.OrderedDict class:

def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
    if key not in self:
        root = self.__root
        last = root[0]
        last[1] = root[0] = self.__map[key] = [last, root, key]
    return dict_setitem(self, key, value)

Notice the last arg: dict_setitem=dict.__setitem__. It makes sense if you think about it. To associate a key with a value, you'll need to provide a __setitem__ method which takes three arguments: the key you're setting, the value associated with the key, and the __setitem__ method of the built-in dict class. Wait. Ok, maybe the last argument makes no sense.

Scope Lookups

To understand what's going on here, we need to take a look at scopes. Let's start with a simple question: if I'm in a python function and I encounter something named open, how does python go about figuring out the value of open?

# <GLOBAL: bunch of code here>

def myfunc():
    # <LOCAL: bunch of code here>
    with open('foo.txt', 'w') as f:
        pass

The short answer is that without knowing the contents of the GLOBAL and the LOCAL section, you can't know for certain the value of open. Conceptually, python checks three namespaces for a name (ignoring nested scopes to keep things simple):

  • locals
  • globals
  • builtin

So in the myfunc function, if we're trying to find a value for open, we'll first check the local namespace, then the global namespace, then the builtins namespace. If open is not defined in any of them, a NameError is raised.
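The conceptual lookup order can be written out as a few lines of Python. This resolve function is a hypothetical illustration, not how CPython works internally; it checks three plain dicts in order and, like the description above, ignores nested scopes.

```python
import builtins


def resolve(name, local_ns, global_ns, builtin_ns=vars(builtins)):
    # Check each namespace in order: locals, then globals, then builtins.
    for namespace in (local_ns, global_ns, builtin_ns):
        if name in namespace:
            return namespace[name]
    raise NameError("name '%s' is not defined" % name)


def my_open(*args):
    """Hypothetical global that shadows the builtin open."""
```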

Scope Lookups, the Implementation

The lookup process above is just conceptual. The way CPython actually implements it gives us room to exploit that implementation.

def foo():
    a = 1
    return a

def bar():
    return a

def baz(a=1):
    return a

Let's look at the bytecode of each function:

>>> import dis
>>> dis.dis(foo)
  2           0 LOAD_CONST               1 (1)
              3 STORE_FAST               0 (a)

  3           6 LOAD_FAST                0 (a)
              9 RETURN_VALUE


>>> dis.dis(bar)
  2           0 LOAD_GLOBAL              0 (a)
              3 RETURN_VALUE


>>> dis.dis(baz)
  2           0 LOAD_FAST                0 (a)
              3 RETURN_VALUE

Look at the differences between foo and bar. Right away we can see that, at the bytecode level, python has already determined what's a local variable and what isn't: foo uses LOAD_FAST and bar uses LOAD_GLOBAL. Note that baz, where a is a default argument, also uses LOAD_FAST; arguments are local variables too.
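If you'd rather check this programmatically, dis.get_instructions (added in Python 3.4) yields each instruction as an object with an opname. Opcode names shift between Python versions (newer releases add specialized variants of LOAD_FAST), so the output won't match the listing above exactly, and this sketch only asks whether a fast or global load appears at all:

```python
import dis


def foo():
    a = 1
    return a


def bar():
    return a  # deliberately refers to a global that is never defined


def baz(a=1):
    return a


def opnames(func):
    return [ins.opname for ins in dis.get_instructions(func)]


def uses_fast_load(func):
    # Matches LOAD_FAST plus newer variants whose names contain it.
    return any('LOAD_FAST' in name for name in opnames(func))


def uses_global_load(func):
    return 'LOAD_GLOBAL' in opnames(func)


for func in (foo, bar, baz):
    print(func.__name__, 'fast:', uses_fast_load(func),
          'global:', uses_global_load(func))
```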

We won't get into the details of how python's compiler knows which bytecode to emit (perhaps that's another post), but suffice it to say that python knows which type of lookup to perform when it executes a function.

One other thing that can be confusing is that LOAD_GLOBAL is used for lookups in the global as well as the builtin namespace. You can think of this as "not local", again ignoring the issue of nested scopes. The C code for this is roughly [1]:

case LOAD_GLOBAL:
    v = PyObject_GetItem(f->f_globals, name);
    if (v == NULL) {
        v = PyObject_GetItem(f->f_builtins, name);
        if (v == NULL) {
            if (PyErr_ExceptionMatches(PyExc_KeyError))
                format_exc_check_arg(
                            PyExc_NameError,
                            NAME_ERROR_MSG, name);
            goto error;
        }
    }
    PUSH(v);

Even if you've never seen any of the C code for CPython, the above code is pretty straightforward. First, check if the key name we're looking for is in f->f_globals (the globals dict), then check if the name is in f->f_builtins (the builtins dict), and finally, raise a NameError if both checks failed.

Binding Constants to the Local Scope

Now, looking back at the initial code sample, we can see that the last parameter binds a function into the function's local scope. It does this by using a value, dict.__setitem__, as the parameter's default value. Here's another example:

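# Original: isinstance, dict, and list are found via LOAD_GLOBAL
def not_list_or_dict(value):
    return not (isinstance(value, dict) or isinstance(value, list))

# Builtins bound as default argument values, so they are found via LOAD_FAST
def not_list_or_dict(value, _isinstance=isinstance, _dict=dict, _list=list):
    return not (_isinstance(value, _dict) or _isinstance(value, _list))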

We're doing the same thing here: binding objects that would normally live in the builtin namespace into the local namespace instead. So rather than using LOAD_GLOBAL (a global lookup), python will use LOAD_FAST. So how much faster is this? Let's do some crude testing:
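You can check that the default-argument trick really removes the global lookups by listing which names each version loads with LOAD_GLOBAL. A sketch using the standard dis module; the second function is renamed not_list_or_dict_fast here only so both versions can coexist, and global_loads is an illustrative helper name.

```python
import dis

def not_list_or_dict(value):
    return not (isinstance(value, dict) or isinstance(value, list))

def not_list_or_dict_fast(value, _isinstance=isinstance, _dict=dict, _list=list):
    return not (_isinstance(value, _dict) or _isinstance(value, _list))

def global_loads(func):
    """Names the function resolves with LOAD_GLOBAL (i.e. via dict lookups)."""
    return [ins.argval for ins in dis.get_instructions(func)
            if ins.opname == "LOAD_GLOBAL"]

print(global_loads(not_list_or_dict))       # ['isinstance', 'dict', 'isinstance', 'list']
print(global_loads(not_list_or_dict_fast))  # []
```

This lists every LOAD_GLOBAL in the bytecode, regardless of whether the `or` short-circuits at runtime.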

$ python -m timeit -s 'def not_list_or_dict(value): return not (isinstance(value, dict) or isinstance(value, list))' 'not_list_or_dict(50)'
1000000 loops, best of 3: 0.48 usec per loop
$ python -m timeit -s 'def not_list_or_dict(value, _isinstance=isinstance, _dict=dict, _list=list): return not (_isinstance(value, _dict) or _isinstance(value, _list))' 'not_list_or_dict(50)'
1000000 loops, best of 3: 0.423 usec per loop

Or in other words, that's an 11.9% improvement [2]. That's way more than the 5% I promised at the beginning of this post!
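To rerun this comparison from inside Python, the standard timeit module's repeat function works too. A sketch: best_usec and improvement are illustrative helper names, not part of timeit, and the percentage is computed the same way as above, (0.48 - 0.423) / 0.48 ≈ 11.9%. Absolute numbers will differ by machine and python version.

```python
import timeit

def not_list_or_dict(value):
    return not (isinstance(value, dict) or isinstance(value, list))

def not_list_or_dict_fast(value, _isinstance=isinstance, _dict=dict, _list=list):
    return not (_isinstance(value, _dict) or _isinstance(value, _list))

def best_usec(func, number=1000000, repeat=3):
    """Best-of-`repeat` time per call in microseconds, like `python -m timeit`."""
    timings = timeit.repeat("func(50)", globals={"func": func},
                            number=number, repeat=repeat)
    return min(timings) / number * 1e6

def improvement(slow, fast):
    """Percentage reduction in per-call time, e.g. (0.48 - 0.423) / 0.48 * 100."""
    return (slow - fast) / slow * 100

if __name__ == "__main__":
    slow = best_usec(not_list_or_dict)
    fast = best_usec(not_list_or_dict_fast)
    print("%.3f usec vs %.3f usec: %.1f%% less time per call"
          % (slow, fast, improvement(slow, fast)))
```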

There's More to the Story

It's reasonable to think that the speed improvement is because LOAD_FAST reads from the local namespace whereas LOAD_GLOBAL first checks the global namespace before falling back to the builtin namespace. And in the example function above, isinstance, dict, and list all come from the builtin namespace.

However, there's more going on. With LOAD_FAST we don't just skip the additional lookup; it's also a different type of lookup altogether.

The C code snippet above showed the code for LOAD_GLOBAL, but here's the code for LOAD_FAST:

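case LOAD_FAST: {
    PyObject *value = fastlocals[oparg];
    if (value == NULL) {
        /* The slot exists but was never assigned: UnboundLocalError */
        format_exc_check_arg(PyExc_UnboundLocalError,
                             UNBOUNDLOCAL_ERROR_MSG,
                             PyTuple_GetItem(co->co_varnames, oparg));
        goto error;
    }
    /* The array holds a borrowed reference; take our own before pushing */
    Py_INCREF(value);
    PUSH(value);
    FAST_DISPATCH();
}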

We're retrieving the local value by indexing into an array. It's not shown here, but oparg is just an index into that array.
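The same indexing is visible from Python. For a single-variable LOAD_FAST, dis reports the oparg as Instruction.arg, and using it as an index into the code object's co_varnames gives back the variable's name. A small sketch; which LOAD_FAST variants appear depends on the CPython version, so the loop only inspects those that carry a single name.

```python
import dis

def not_list_or_dict(value, _isinstance=isinstance, _dict=dict, _list=list):
    return not (_isinstance(value, _dict) or _isinstance(value, _list))

code = not_list_or_dict.__code__
print(code.co_varnames)  # ('value', '_isinstance', '_dict', '_list')

# Single-variable LOAD_FAST variants carry one slot index in ins.arg;
# fused forms such as LOAD_FAST_LOAD_FAST report a tuple argval and are skipped.
for ins in dis.get_instructions(not_list_or_dict):
    if ins.opname.startswith("LOAD_FAST") and isinstance(ins.argval, str):
        print(ins.opname, ins.arg, "->", code.co_varnames[ins.arg])
```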

Now it's starting to make sense. Our first version of not_list_or_dict performs 4 lookups, and each name lives in the builtins namespace, which is only checked after the globals namespace. That's 8 dictionary key lookups. Compare that to the second version of not_list_or_dict, where all 4 names are loaded with LOAD_FAST, which indexes directly into a C array 4 times. This is why lookups in the local namespace are faster.

Wrapping Up

Now the next time you see this in someone else's code you'll know what's going on.

And one final thing. Please don't actually do these kinds of optimizations unless you really need to. And most of the time you don't need to. But when the time really comes, and you really need to squeeze out every last bit of performance, you'll have this in your back pocket.

Footnotes

[1]Though keep in mind that I removed some performance optimizations in the above code to make it simpler to read. The real code is slightly more complicated.
[2]On a toy example for a function that doesn't really do anything interesting nor does it perform any IO and is mostly bound by the python VM loop.

JMESPath is an expression language that allows you to manipulate JSON. From selecting specific keys in a hash to selecting only the elements that match certain filter criteria, JMESPath gives you a lot of power when working with JSON.

In my experience, the quickest way to get up to speed with a language is to try it out. The JMESPath tutorial gives you a brief introduction to the language, but to really solidify the concepts you just need to spend some time experimenting with it.

You could accomplish this by using one of the existing JMESPath libraries, but there's an easier way: the JMESPath terminal. Either specify a JSON file to use, or pipe the JSON document into the jmespath-terminal command.

The JMESPath terminal README has instructions on getting set up and on how to use the JMESPath terminal.

Check it out, and feel free to leave any feedback and suggestions on the issue tracker.


I've just released 0.4.0 of semidbm. This release includes a number of really cool features. See the full changelog for more details.

One of the biggest features is python 3 support. I was worried that supporting python 3 would introduce a performance regression. Fortunately, this was not the case.

In fact, performance increased, for a number of reasons. First, the index file and data file were combined into a single file, which means that a __setitem__ call results in only a single write() call. Second, semidbm now uses a binary format, which is more compact and makes it easier to create the sequence of bytes we need to write out to disk. And these gains hold even though semidbm now also writes checksum data for every write that occurs.
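To make the single-write idea concrete, here is a minimal sketch of an append-only binary record with a checksum. The layout (two big-endian uint32 lengths, key, value, then a CRC-32) and the names encode_record and decode_record are hypothetical choices for illustration; this is not semidbm's actual on-disk format.

```python
import struct
import zlib

# Hypothetical layout: key length and value length as big-endian uint32s,
# then the key bytes, the value bytes, and a CRC-32 of everything before it.
HEADER = struct.Struct(">II")
CRC = struct.Struct(">I")

def encode_record(key, value):
    """Serialize one key/value write into a single bytes object."""
    body = HEADER.pack(len(key), len(value)) + key + value
    return body + CRC.pack(zlib.crc32(body))

def decode_record(buf, offset=0):
    """Parse the record starting at offset; return (key, value, next_offset)."""
    key_len, value_len = HEADER.unpack_from(buf, offset)
    key_start = offset + HEADER.size
    end = key_start + key_len + value_len
    (stored_crc,) = CRC.unpack_from(buf, end)
    if zlib.crc32(buf[offset:end]) != stored_crc:
        raise ValueError("checksum mismatch in record at offset %d" % offset)
    key = bytes(buf[key_start:key_start + key_len])
    value = bytes(buf[key_start + key_len:end])
    return key, value, end + CRC.size
```

Because each record is one bytes object, a write like db[key] = value can be issued as a single f.write(encode_record(key, value)), and reading the file back is a loop over decode_record until next_offset reaches the end of the data.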

Try it out for yourself.

What's Next?

I think at this point semidbm has more than met its original goal, which was to be a pure python, cross platform key value store with reasonable performance. So what's next for semidbm? In a nutshell, higher level abstractions (aka the "fun stuff"): code that builds on the simple key value storage of semidbm.db and provides additional features. And as we move to higher levels, I think it makes sense to reevaluate the original goals of semidbm and whether it makes sense to carry them forward:

  • Cross platform. I'm inclined not to support Windows for these higher level abstractions.
  • Pure python. I think the main reason for remaining pure python was ease of installation. Especially on Windows, pip installing a package should just work, and C extensions make that much harder. If semidbm isn't going to support Windows for these higher level abstractions, then C extensions are fair game.

Some ideas I've been considering:

  • A C version of _Semidbm.
  • A dict like interface that is concurrent (possibly single writer multiple reader).
  • A sorted version of semidbm (supporting things like range queries).
  • Caching reads (need an efficient LRU cache).
  • Automatic background compaction of data file.
  • Batched writes.
  • Transactions.
  • Compression. I played around with this earlier, and zlib turned out to be too slow for smaller values (~100 bytes), but it might be worth making this configurable on a per-db basis.
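For the read-caching idea, one common approach is an LRU cache built on collections.OrderedDict and wrapped around any dict-like db. This is a hypothetical sketch: the class LRUReadCache and its maxsize parameter are not part of semidbm. functools.lru_cache wasn't used because it caches function calls and has no way to invalidate a single key when that key is overwritten.

```python
from collections import OrderedDict

class LRUReadCache:
    """Hypothetical read-through LRU cache in front of any dict-like db."""

    def __init__(self, db, maxsize=1024):
        self._db = db
        self._maxsize = maxsize
        self._cache = OrderedDict()

    def __getitem__(self, key):
        if key in self._cache:
            self._cache.move_to_end(key)      # hit: mark as most recently used
            return self._cache[key]
        value = self._db[key]                 # miss: read through (KeyError propagates)
        self._cache[key] = value
        if len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)   # evict the least recently used entry
        return value

    def __setitem__(self, key, value):
        self._db[key] = value
        self._cache.pop(key, None)            # drop any stale cached copy
```

Both the hit path and eviction are O(1), which matters if the whole point of the cache is to beat a disk read.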