How to call Lambda functions in batch from Redshift

Harry Glaser

Often, the best way to make machine learning models callable from a warehouse is to deploy them as lambda functions. The lambda can be customized to match your Python development environment. If you package it as a Docker container image, it can be up to 10 GB in size, big enough to hold a memory-hungry language or image model.

With Redshift, you don’t even need to slap API Gateway in front of that lambda function before calling it from your warehouse. And if data science and warehousing are organizationally separate, no problem: Redshift can call a lambda function that lives in another AWS account.

The problem: Batch inference efficiency 

There’s a hitch: Most of the time, when calling ML models from warehouses, we’re doing batch inference. But when you call your model in batch using an UPDATE statement, Redshift calls your model one row at a time! Not exactly efficient.

To see how bad it gets, let’s take a look at our trusty NBA points predictor. This model is a simple linear regression that predicts the number of points an NBA player will score. It’s loaded into a lambda function that’s called by the Redshift external function “ext_predict_points_latest”.

Because this can get quite slow, we’ll work with a “slim_nba_game_details” table that has only 500 rows.
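The post doesn’t show the lambda’s code. For illustration only, here is a minimal sketch of what a batch-friendly handler behind `ext_predict_points_latest` could look like, assuming the request/response format AWS documents for Redshift Lambda UDFs: Redshift sends the rows under an `arguments` key (one inner list per row) and expects back a JSON string whose `results` list holds one value per row, in the same order. The coefficients `INTERCEPT` and `SLOPE` are made-up placeholders, not the real model.

```python
import json

# Made-up placeholder coefficients for a one-feature linear regression on FGM.
INTERCEPT = 1.5
SLOPE = 2.25


def predict_points(fgm):
    """Score one row; SQL NULLs arrive as None and are passed back as null."""
    return None if fgm is None else INTERCEPT + SLOPE * fgm


def lambda_handler(event, context):
    # Each inner list holds one row's arguments to the external function.
    rows = event["arguments"]
    # print() output lands in CloudWatch Logs, which shows how many rows
    # Redshift packed into this single invocation.
    print(f"batch size: {len(rows)}")
    try:
        results = [predict_points(row[0]) for row in rows]
    except Exception as exc:
        # Report the failure to Redshift rather than crashing the invocation.
        return json.dumps({"success": False, "error_msg": str(exc)})
    # One result per input row, in the same order Redshift sent them.
    return json.dumps({"success": True, "num_records": len(results), "results": results})
```

Logging the batch size is a cheap way to check the behavior described below: a row-at-a-time UPDATE should log many single-row batches, while a SELECT should log far fewer, larger ones.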

Let’s try calling our lambda function in a simple UPDATE statement that asks for 500 inferences in batch:

{%CODE sql%}
update slim_nba_game_details set predicted_points = ext_predict_points_latest(FGM);

This takes a whopping 3m46.9s! That’s about 450ms per inference. If we want to scale this up, this architecture will not fly.

The solution: Inference in a SELECT

The problem is that Redshift doesn’t like to batch calls inside an UPDATE statement. It’s happy to do so, though, inside a SELECT statement.
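Why does that matter so much? A toy cost model helps: if every lambda invocation pays a fixed overhead (the network round trip and request handling) on top of the per-row scoring cost, then sending one row per call pays that overhead once per row. The overhead, per-row cost, and batch size below are illustrative assumptions, not measurements from this post, and Redshift chooses its real batch sizes itself.

```python
import math


def estimated_seconds(n_rows, batch_size, per_call_overhead_s, per_row_s):
    """Toy model: each invocation pays a fixed overhead, each row a scoring cost."""
    n_calls = math.ceil(n_rows / batch_size)
    return n_calls * per_call_overhead_s + n_rows * per_row_s


# Illustrative assumptions only, not measured values.
OVERHEAD_S = 0.4    # assumed fixed cost per lambda invocation
PER_ROW_S = 0.001   # assumed cost to score one row

row_at_a_time = estimated_seconds(500, 1, OVERHEAD_S, PER_ROW_S)
batched = estimated_seconds(500, 100, OVERHEAD_S, PER_ROW_S)  # hypothetical batch size
print(f"one row per call: {row_at_a_time:.1f}s, 100 rows per call: {batched:.1f}s")
```

In this model the overhead term shrinks in proportion to the batch size while the per-row term stays fixed, which is the effect the SELECT-based approach takes advantage of.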

So how to proceed? Simple! Ask for all your predictions first, in a SELECT statement. To capture them, we’ll use a temporary table. Let’s begin by creating that table.

{%CODE sql%}
-- float8: a bare numeric defaults to numeric(18,0) in Redshift, rounding predictions
create temp table predictions_tbl (uuid varchar, predicted_points float8);

This took a near-instantaneous 19ms. We do need a join key into our main table. Our table happens to use UUIDs, so we’ve got one here. Integer primary keys are just as good.

Now let’s get our inferences in batch, inserted into the temp table:

{%CODE sql%}
insert into predictions_tbl (uuid, predicted_points)
select uuid, ext_predict_points_latest(FGM) as predicted_points
from slim_nba_game_details;

Note that we’re selecting the UUIDs out of the slim_nba_game_details table at inference time, so we can use them to join later!

Importantly, this took only 5.1s! 10ms per inference, or about 45X faster than the UPDATE call! If you’re doing large batches, this is well worth the extra complexity.

But of course, now that we’ve captured the inferences in a temp table, we need to update our main table with the results. Here’s the SQL:

{%CODE sql%}
update slim_nba_game_details
set predicted_points = predictions_tbl.predicted_points
from predictions_tbl
-- Match each row to its prediction on the shared uuid key
where slim_nba_game_details.uuid = predictions_tbl.uuid;

This adds another 9.4s to the operation, for a total of 14.5s. That’s still over 15X faster than the row-at-a-time UPDATE. Not bad for a little SQL surgery.
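The per-inference and speedup figures above follow directly from the measured timings. This snippet redoes the arithmetic using only the numbers reported in this post.

```python
# Timings reported above for the 500-row table.
ROWS = 500
update_only_s = 3 * 60 + 46.9   # row-at-a-time UPDATE: 3m46.9s
select_insert_s = 5.1           # batched INSERT ... SELECT into the temp table
join_update_s = 9.4             # UPDATE ... FROM the temp table

total_s = select_insert_s + join_update_s
print(f"UPDATE: {update_only_s / ROWS * 1000:.0f} ms per row")
print(f"SELECT: {select_insert_s / ROWS * 1000:.1f} ms per row, "
      f"{update_only_s / select_insert_s:.1f}x faster")
print(f"end to end: {total_s:.1f}s, {update_only_s / total_s:.1f}x faster")
```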

Redshift, Lambda and Machine Learning

We won’t lie: All of this is slicker in Snowflake, and the efficiency is just as good in Snowflake without the syntax gymnastics. But if you’re a Redshift shop, or if you’re still choosing warehouses and you've got AWS credits to burn, Redshift can be a strong choice for batch machine learning with lambda functions. You just need to know how to work with it.

Happy modeling! We hope you’ll consider Modelbit if you’re deploying lots of ML models and need to call them from multiple endpoints.
