SQLAlchemy and Race Conditions: Implementing get_one_or_create()

Note: The examples here are written for Flask-SQLAlchemy, and while some notation follows its conventions, the concepts should apply to SQLAlchemy in general.

Motivation

Suppose we have a Flask app that interacts with the API of a site that hosts web games. Our users have authorized our application with the API via OAuth so we can see their activity, and we want to track which games they have beaten. (We'll assume that users only beat a game once or, more generally, that we're only concerned with the first time they beat it.) We create the Game and GameBeat models like so:

class Game(db.Model):
    __tablename__ = 'games'
    id = db.Column(db.Integer, primary_key=True)
    provider_game_id = db.Column(db.String(255))
    provider_game_name = db.Column(db.String(255))
    provider_game_category = db.Column(db.String(255))

class GameBeat(db.Model):
    __tablename__ = 'game_beats'
    id = db.Column(db.Integer, primary_key=True)
    game_id = db.Column(db.Integer, db.ForeignKey('games.id'))
    game = db.relationship("Game",
                           backref="beats",
                           lazy="joined",
                           innerjoin=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'))
    user = db.relationship("User",
                           backref="game_beats",
                           lazy="joined",
                           innerjoin=True)
    beat_at = db.Column(db.DateTime())


A few notes on these models. They assume that each game coming from the API has an id, name, and category attribute, and that any time a user beats a game, you can get the datetime of when it was beaten. One other convention is using db.String(255) for provider_game_id (rather than db.Integer). You will often see this with APIs whose ids can be very large integers (think of the number of Tweets on Twitter, for example). Python handles big ints well, but unless you actually need to do math with the ids, a string is the simpler and safer choice, and you'll be less likely to run into an issue with your backend datastore's integer column types.
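To make the overflow risk concrete, here is a small sketch using Python's built-in sqlite3 module rather than the app's real datastore. SQLite integers are signed 64-bit (PostgreSQL's INTEGER is only 32-bit), so an id beyond that range can't be bound as an integer at all, while the same id stored as text round-trips unchanged. The table and the id value are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE games (int_id INTEGER, str_id TEXT)')

# A hypothetical provider id larger than a signed 64-bit integer.
big_id = 2 ** 64 + 1

try:
    conn.execute('INSERT INTO games (int_id) VALUES (?)', (big_id,))
    int_ok = True
except OverflowError:
    # sqlite3 refuses to bind a Python int that doesn't fit in 64 bits.
    int_ok = False

# Stored as text, the same id is kept exactly.
conn.execute('INSERT INTO games (str_id) VALUES (?)', (str(big_id),))
stored = conn.execute(
    'SELECT str_id FROM games WHERE str_id IS NOT NULL').fetchone()[0]
```

If you ever do need the numeric value, int(stored) recovers it exactly.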

Implementing a basic get_one_or_create()

NOTE: The original implementation I used was first inspired by this Stack Overflow question.

If you've ever played around with Django, you've probably seen the get_or_create() method, but if not, the concept is straightforward: first look for an object given a set of constraints, and if it doesn't exist, create it. The usefulness here is pretty obvious in our example above. Without it, given a list of game_beats from the API and a user, we'd have to do something like:

for game_beat_data in game_beat_data_list:
    if db.session.query(Game).filter(Game.provider_game_id == game_beat_data['game']['game_id']).count() == 0:
        game = Game(provider_game_id=game_beat_data['game']['game_id'],
                    provider_game_name=game_beat_data['game']['game_name'],
                    provider_game_category=game_beat_data['game']['game_category'])
        db.session.add(game)
    else:
        game = db.session.query(Game).filter(Game.provider_game_id == game_beat_data['game']['game_id']).one()
    if db.session.query(GameBeat).filter(GameBeat.game == game).filter(GameBeat.user == user).count() == 0:
        game_beat = GameBeat(game=game,
                             user=user,
                             beat_at=game_beat_data['beat_at'])
        db.session.add(game_beat)
    else:
        game_beat = db.session.query(GameBeat).filter(GameBeat.game == game).filter(GameBeat.user == user).one()
db.session.commit()


This is quite verbose, and the logic is repeated twice, so it seems like a perfect place to write a function. Here's a first attempt (we'll build from here):

from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one()
    except NoResultFound:
        created = model(**kwargs)
        session.add(created)
        return created


There are a couple of small differences here. The most obvious is the switch to a try/except block rather than checking count() explicitly. The advantage is that if the object exists, we only have to make one call to the datastore. The other is the use of filter_by(), which is just a version of filter() that takes keyword arguments.
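Here is a minimal runnable sketch of this first version, assuming plain SQLAlchemy 1.4 or newer with an in-memory SQLite database instead of Flask-SQLAlchemy's db.session; the trimmed-down Game model is illustrative only. One assumption worth flagging: the sketch calls session.add() on the new instance, so the session's autoflush makes it visible to later queries and a second lookup returns the same object instead of building another one.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm.exc import NoResultFound

Base = declarative_base()


class Game(Base):
    # Trimmed-down, illustrative version of the Game model.
    __tablename__ = 'games'
    id = Column(Integer, primary_key=True)
    provider_game_id = Column(String(255))


def get_one_or_create(session, model, **kwargs):
    try:
        # A single round trip when the row already exists.
        return session.query(model).filter_by(**kwargs).one()
    except NoResultFound:
        instance = model(**kwargs)
        # Pending in the session; autoflush lets later queries see it.
        session.add(instance)
        return instance


engine = create_engine('sqlite://')  # in-memory database for the demo
Base.metadata.create_all(engine)

with Session(engine) as session:
    first = get_one_or_create(session, Game, provider_game_id='12345')
    second = get_one_or_create(session, Game, provider_game_id='12345')
    session.commit()
    same_object = first is second
    game_count = session.query(Game).count()
```

The second call finds the row flushed by the first, so only one Game is stored.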

Using this function, our above example now becomes:

for game_beat_data in game_beat_data_list:
    game = get_one_or_create(
        db.session,
        Game,
        provider_game_id=game_beat_data['game']['game_id']
    )
    game_beat = get_one_or_create(
        db.session,
        GameBeat,
        game=game,
        user=user,
        beat_at=game_beat_data['beat_at']
    )
db.session.commit()


Since I've split the function's arguments across multiple lines, we haven't saved many lines, but I'd still argue this is a HUGE decrease in complexity. We can now fit within 80 characters without having to use \ for line continuations, but more generally, it's simply more readable.

A small modification

I'm not going to motivate this, but sometimes you'll want to know whether the object you're getting back was newly created or pulled from the datastore, so we can easily return a bool as well (True meaning it already existed) by updating our function to:

from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one(), True
    except NoResultFound:
        created = model(**kwargs)
        session.add(created)
        return created, False


and now we just have to remember to unpack the result into two variables when we use it

obj, exists = get_one_or_create( … )


or else we will end up with a tuple where we expected obj.

The @classmethod decorator

One of my favorite Python patterns is using the @classmethod decorator with a creator function. In our above example, suppose we want to be able to create a Game object, but we may or may not have the provider_game_name and provider_game_category immediately available. We can update our model definition (assuming we have a get_game_data_from_api function from our API):

class Game(db.Model):
    __tablename__ = 'games'
    id = db.Column(db.Integer, primary_key=True)
    provider_game_id = db.Column(db.String(255))
    provider_game_name = db.Column(db.String(255))
    provider_game_category = db.Column(db.String(255))

    @classmethod
    def create_from_provider_game_id(cls, provider_game_id, **kwargs):
        provider_game_name = kwargs.get('provider_game_name', None)
        provider_game_category = kwargs.get('provider_game_category', None)
        # Only call the API if either piece of data is missing.
        if not (provider_game_name and provider_game_category):
            game_data = get_game_data_from_api(provider_game_id)
            provider_game_name = game_data['provider_game_name']
            provider_game_category = game_data['provider_game_category']
        return cls(provider_game_id=provider_game_id,
                   provider_game_name=provider_game_name,
                   provider_game_category=provider_game_category)


Now if we do

game = Game.create_from_provider_game_id(provider_game_id = game_beat_data['game']['id'])


we get a new Game object, and the API is pinged to fill in the other data. But if we already have the data (or might have it), we can do the following:

game = Game.create_from_provider_game_id(
    provider_game_id=game_beat_data['game']['id'],
    provider_game_name=game_beat_data['game'].get('name', None),
    provider_game_category=game_beat_data['game'].get('category', None)
)


I like to think of writing code like this as being fault tolerant. If you've worked with external APIs before, you know their behavior can change slightly over time. Moreover, writing code like this gives me a consistent pattern, so if in one case I have a collection of dicts from an API with all the relevant game data, and in another case I only have a collection of the provider_game_ids, I can use the same pattern. If I have all the data, I can save myself n calls to the API, and if I'm not sure (or maybe the API is inconsistent; yes, it happens), it will still work.
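The sketch below exercises this pattern without a database. Both get_game_data_from_api and the plain Game class are hypothetical stand-ins (the stub just records which ids it was asked about), which makes it easy to see that the API is only hit when some of the data is missing.

```python
api_calls = []


def get_game_data_from_api(provider_game_id):
    """Hypothetical stand-in for the real provider API call."""
    api_calls.append(provider_game_id)
    return {'provider_game_name': 'Game %s' % provider_game_id,
            'provider_game_category': 'puzzle'}


class Game:
    """Plain-Python stand-in for the Game model (no database involved)."""

    def __init__(self, provider_game_id, provider_game_name=None,
                 provider_game_category=None):
        self.provider_game_id = provider_game_id
        self.provider_game_name = provider_game_name
        self.provider_game_category = provider_game_category

    @classmethod
    def create_from_provider_game_id(cls, provider_game_id, **kwargs):
        name = kwargs.get('provider_game_name')
        category = kwargs.get('provider_game_category')
        # Only hit the API when some of the data is missing.
        if not (name and category):
            game_data = get_game_data_from_api(provider_game_id)
            name = game_data['provider_game_name']
            category = game_data['provider_game_category']
        return cls(provider_game_id=provider_game_id,
                   provider_game_name=name,
                   provider_game_category=category)


# All data on hand: no API call.
full = Game.create_from_provider_game_id(
    '1', provider_game_name='Tetris', provider_game_category='puzzle')
# Only the id on hand: the stub API fills in the rest.
partial = Game.create_from_provider_game_id('2')
```

After these two calls, api_calls contains only '2', since the first game already had everything it needed.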

Updating get_one_or_create

It quickly becomes obvious that our current implementation of get_one_or_create() can't handle this pattern for creating model instances. But this is programming; we are the creators. Let's update our function:

from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      create_method='',
                      create_method_kwargs=None,
                      **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one(), True
    except NoResultFound:
        kwargs.update(create_method_kwargs or {})
        created = getattr(model, create_method, model)(**kwargs)
        session.add(created)
        return created, False


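The create_method argument names an alternative constructor on the model, such as create_from_provider_game_id; when it's left empty, getattr(model, '', model) finds no such attribute and falls back to the model class itself. The reason for adding the create_method_kwargs dict is fairly straightforward: consider a case where we need to hand a user to our API method get_game_data_from_api, but we don't store a user on the Game object and certainly don't want to include it in the filter_by.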
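A small sketch of the create half on its own, with no session involved, shows how the dispatch behaves. The build() helper and the user value are hypothetical; the point is that create_method_kwargs reach the create method without being stored as model fields, and that an empty create_method falls back to the plain constructor because getattr returns its default.

```python
class Game:
    """Plain-Python stand-in for the model (hypothetical)."""
    last_user = None

    def __init__(self, **fields):
        self.fields = fields

    @classmethod
    def create_from_provider_game_id(cls, provider_game_id, user=None):
        # In the real app `user` would be handed to the API call; here we
        # only record that it arrived. It is not stored as a model field.
        cls.last_user = user
        return cls(provider_game_id=provider_game_id)


def build(model, create_method='', create_method_kwargs=None, **kwargs):
    """Hypothetical extraction of the 'create' branch of get_one_or_create()."""
    params = dict(kwargs, **(create_method_kwargs or {}))
    # getattr(model, '', model) finds no attribute named '' and returns the
    # default, i.e. the model class itself (its plain constructor).
    return getattr(model, create_method, model)(**params)


plain = build(Game, provider_game_id='7')
via_classmethod = build(Game,
                        create_method='create_from_provider_game_id',
                        create_method_kwargs={'user': 'alice'},
                        provider_game_id='7')
```

Both objects end up with the same fields; only the classmethod path saw the extra user argument.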

At this point, we have a pretty robust get_one_or_create() function that worked quite well for me, until it didn't.

An Unexpected Exception

Since I was using my newly minted get_one_or_create() everywhere I created objects, you can imagine my surprise when MultipleResultsFound exceptions started being raised while trying to get_one_or_create() a Game. I scoured my code for any place that I created a Game; luckily grep makes this easy. Nothing. I expanded my test so that I was mocking a scan for 100 users. Nothing. Maybe there was a unicode issue where the get part was trying to match a unicode value against a str (and failing), but the create part was then storing the attribute as a str. A new test quickly disproved this. Nothing.

I was lost. I was writing tests hoping that they would fail just so I could figure out what was going wrong. And it was 4am. I hate going to sleep at a point like this. (I use sleep loosely here; I generally don't sleep well and lie awake trying to figure out how to fix my problem.) I'm not sure if it was that night or the next morning (or even where you draw the line when going to bed at 4am, for that matter…), but I eventually realized why I was getting duplicate objects in my datastore.

In the time between a get call that found no object and the call to db.session.commit(), another process had created and committed the same object. I was using Celery to batch the update, and it was a brand-new scan, so a lot of stuff was being created. Since the db.session.commit() comes at the end of the loop, there can actually be a fair amount of time between these calls. But even if a db.session.commit() happened sooner, we want to be sure that another matching object didn't get created between the get and create parts of the function.

Celery "async" and Race Conditions

@celery.task()

I use quotes here in the way I'd use air quotes if we were talking: the Celery async task model isn't true async in the way that Twisted or asyncio is. Instead it uses an external message broker (I use Redis, mostly due to the price on Heroku). When you call a task, the task's name and its arguments are serialized (as JSON by default in recent Celery versions), stored in the broker, and then deserialized and executed by a worker in a completely separate process.

It does become more complicated if you care about the result of a function call, but often you call functions only for their side effects. Sometimes you'll want to kick off a job of a few such function calls as the result of a web request. If the job takes a while (anything more than a second), you certainly don't want to wait for it to complete before responding (and eventually you can't, because the request will time out). Celery makes this pattern straightforward, if not easy, to accomplish with a synchronous web framework like Flask.

Celery also makes it really easy to scale a batch of jobs horizontally. Suppose we want to update the GameBeats for all our users. We could create a Celery task like so:

from app import celery, db
from app.models import User
from app.provider_apis import get_provider_api

@celery.task()
def update_game_beats_for_user_id(user_id):
    user = db.session.query(User).get(user_id)
    api = get_provider_api(user)
    game_beat_data_list = api.get_game_beats()
    for game_beat_data in game_beat_data_list:
        …


where the … completes the pattern above showing the usage of get_one_or_create(). We can now create a task for each user by running

from app.tasks import update_game_beats_for_user_id
from app import db
from app.models import User

users = db.session.query(User).all()
for user in users:
    update_game_beats_for_user_id.delay(user.id)


This makes it possible to spin up a number of Celery workers (Heroku for this is as easy as scrolling a slider on a web interface and clicking "Apply", or a single short bash command) and get all these tasks done quickly.
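One detail in the loop above matters because of how Celery ships work to workers: it passes user.id, not the User object. The arguments have to survive serialization, and a model instance generally can't be encoded as JSON, so the worker re-loads the user through its own session instead. The sketch below is a simplified illustration only; serialize_task_call is hypothetical, and Celery's real message format carries more than this.

```python
import json


class User:
    """Stand-in for the SQLAlchemy User model."""

    def __init__(self, id):
        self.id = id


def serialize_task_call(task_name, *args):
    # Hypothetical, simplified message: just a task name and JSON args.
    return json.dumps({'task': task_name, 'args': list(args)})


user = User(42)
message = serialize_task_call('app.tasks.update_game_beats_for_user_id', user.id)

try:
    serialize_task_call('app.tasks.update_game_beats_for_user_id', user)
    object_ok = True
except TypeError:
    # json can't encode an arbitrary object such as a model instance.
    object_ok = False
```

Passing the primary key also means the worker reads fresh data from the database rather than a stale snapshot of the object.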

Race Conditions

Whenever you have multiple concurrent processes interacting with an external service, you have to be concerned about race conditions: from the perspective of any one process, the state of that external service may change at any time without any interaction from that process. Luckily, Celery handles its own side of this: it routes each task message to a single worker, so two workers won't normally pick up the same job.

SQLAlchemy itself doesn't coordinate between separate processes; whatever protection exists comes from the database's own transactions and locking. I do know that if I have an ipython shell open with a transaction pending, some queries in another process will stall. That doesn't prevent our problem, though: two Celery workers, whether on separate Heroku instances or even separate web instances, can each run the "get" query, find nothing, and each insert a row. So we certainly want to protect against this type of thing in production.

In our example, since we plan to look up Game objects by provider_game_id and expect only one instance, we should have defined that column with a unique constraint like so:

class Game(db.Model):
    …
    provider_game_id = db.Column(db.String(255), unique=True)
    …


Luckily we can make this change after the fact (which I won't show here, but I plan to write about next). Had we done this originally, we would never have gotten the MultipleResultsFound exception; instead, we would have gotten an IntegrityError when calling db.session.commit() on a transaction that would create our duplicate item. This is desirable since the exception is raised before the duplicate makes it into the datastore. However, it would be best for our get_one_or_create() function to take care of this as well.

A "final" get_one_or_create()

Again, think air quotes. As a developer, I never think of anything as final. Or maybe I always think an implementation is final, and I'm always wrong. Regardless, given the concerns I've run into so far, here is a "final" version:

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      create_method='',
                      create_method_kwargs=None,
                      **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one(), True
    except NoResultFound:
        params = dict(kwargs, **(create_method_kwargs or {}))  # kwargs stays filter-only
        created = getattr(model, create_method, model)(**params)
        try:
            session.add(created)
            session.commit()
            return created, False
        except IntegrityError:  # another process committed a matching row first
            session.rollback()
            return session.query(model).filter_by(**kwargs).one(), True


Here's what's happening. The function tries to find the right object, and if it finds it, it returns it. If not, it creates it, adds it to the session, and attempts to commit. If, in the meantime, another process has created and committed an object with the same unique values we filtered by, an IntegrityError is raised and caught. That means the other process's object is now in the database, so we roll back, query for it, and return it.
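To watch the IntegrityError branch actually fire, the sketch below runs this version against a file-backed SQLite database, assuming SQLAlchemy 1.4 or newer. The race is forced by a hypothetical create method, create_racing, which commits the same game through a second session just before our function tries to commit its own copy. The sketch also keeps the filter kwargs separate from create_method_kwargs (via params) so the fallback query filters only on the unique values.

```python
import os
import tempfile

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm.exc import NoResultFound

Base = declarative_base()
engine = create_engine(
    'sqlite:///' + os.path.join(tempfile.mkdtemp(), 'games.db'))


class Game(Base):
    __tablename__ = 'games'
    id = Column(Integer, primary_key=True)
    provider_game_id = Column(String(255), unique=True)

    @classmethod
    def create_racing(cls, provider_game_id):
        # Hypothetical create method that simulates another worker: it
        # commits the same game through a separate session first.
        with Session(engine) as other:
            other.add(cls(provider_game_id=provider_game_id))
            other.commit()
        return cls(provider_game_id=provider_game_id)


def get_one_or_create(session, model, create_method='',
                      create_method_kwargs=None, **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one(), True
    except NoResultFound:
        params = dict(kwargs, **(create_method_kwargs or {}))
        created = getattr(model, create_method, model)(**params)
        try:
            session.add(created)
            session.commit()
            return created, False
        except IntegrityError:
            # Someone else committed a matching row first; use theirs.
            session.rollback()
            return session.query(model).filter_by(**kwargs).one(), True


Base.metadata.create_all(engine)

with Session(engine) as session:
    game, exists = get_one_or_create(session, Game,
                                     create_method='create_racing',
                                     provider_game_id='12345')
    found_id = game.provider_game_id
    total = session.query(Game).count()
```

The function reports exists as True and the table still holds a single row, even though our own INSERT lost the race.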

You'll notice one major difference from our earlier versions: session.commit() is now in the function. I've avoided doing this thus far because of the recommended usage in the SQLAlchemy documentation on sessions:

As a general rule, keep the lifecycle of the session separate and external from functions and objects that access and/or manipulate database data.

The reason we must add this into our function here is that if we call db.session.commit() only after making several calls to get_one_or_create() and an IntegrityError is raised, we'd have no idea which created object caused it, and we'd have to start the entire transaction over (and write logic outside the function to handle all of this). Since our get_one_or_create() function is written generally for any model and unique keywords, rather than as specific logic tied to a specific model, I don't think it seriously conflicts with the SQLAlchemy philosophy.

UPDATE: I have recanted here and decided that using session.flush() is superior to session.commit() for the commit call inside the function. See my newer blog post for a detailed explanation of why.

UPDATE 2: I've stopped using {} as a default argument value in the function, since mutable default arguments are a typical Python gotcha. Thanks for the comment, Nigel! If you're curious about this gotcha, check out this StackOverflow question and this blog post.
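For reference, here is the gotcha in isolation: a default value is evaluated once, when the function is defined, so every call that relies on a mutable default shares the same dict. The function names are illustrative.

```python
def merge_with_shared_default(extra={}, **kwargs):
    # BUG: this {} is created once and shared across every call.
    extra.update(kwargs)
    return extra


def merge_with_none_default(extra=None, **kwargs):
    # Idiomatic fix: build a fresh dict on each call.
    extra = dict(extra or {})
    extra.update(kwargs)
    return extra


first = merge_with_shared_default(a=1)
second = merge_with_shared_default(b=2)  # still contains a=1 from the first call

fresh_first = merge_with_none_default(a=1)
fresh_second = merge_with_none_default(b=2)
```

The None-default version never leaks state between calls, which is why the function above defaults create_method_kwargs to None and substitutes {} at call time.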

Final Thoughts

Race conditions and asynchronous programming are difficult, especially when working in a framework that doesn't force you to work or think that way. Flask, and certainly Flask-SQLAlchemy, aren't extremely well suited to handle this, but with Flask's synchronous handling of requests, it doesn't become an issue often. When scaling, however, you begin to increase the probability of such occurrences.

How does this work in Tornado or Twisted? My friend and former colleague Brian Muller wrote Twistar, a Python implementation of a nonblocking active record pattern interface to relational databases (built to use with Twisted). It's been on my list of packages to read. To some extent, building a webapp in Twisted has a high upfront cost, but it also has this added value of forcing you to think deeper and see these problems before you have a few hundred duplicated entries in your database.

I love Flask and its community, and both have been very good to me as someone quite new to web app development, but part of me wants to start playing around with JS frameworks like React. This also makes sense as you start getting into the realm of iOS and Android apps. At some point you have many different "front ends" and your backend is really just serving up data to fill them. At that point, a centralized template system like Jinja makes less sense, and moving from Flask to Twisted makes more sense.