
Retrieval models are often built to surface a handful of top candidates out of millions or even hundreds of millions of candidates. To be able to react to the user's context and behaviour, they need to be able to do this on the fly, in a matter of milliseconds.

Approximate nearest neighbour search (ANN) is the technology that makes this possible. In this tutorial, we'll show how to use ScaNN, a state-of-the-art nearest neighbour retrieval package, to seamlessly scale TFRS retrieval to millions of items.

## What is ScaNN?

ScaNN is a library from Google Research that performs dense vector similarity search at large scale. Given a database of candidate embeddings, ScaNN indexes these embeddings in a manner that allows them to be rapidly searched at inference time. ScaNN uses state-of-the-art vector compression techniques and carefully implemented algorithms to achieve a strong speed-accuracy tradeoff. It can greatly outperform brute force search while sacrificing little in terms of accuracy.

## Building a ScaNN-powered model

To try out ScaNN in TFRS, we'll build a simple MovieLens retrieval model, just as we did in the basic retrieval tutorial. If you have followed that tutorial, this section will be familiar and can safely be skipped.

To start, install TFRS and TensorFlow Datasets:

`pip install -q tensorflow-recommenders`

`pip install -q --upgrade tensorflow-datasets`

We also need to install `scann`: it's an optional dependency of TFRS, and so needs to be installed separately.

`pip install -q scann`

Set up all the necessary imports.

```
from typing import Dict, Text
import os
import pprint
import tempfile
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
```

```
import tensorflow_recommenders as tfrs
```

And load the data:

```
# Load the MovieLens 100K data.
ratings = tfds.load(
    "movielens/100k-ratings",
    split="train"
)
# Get the ratings data.
ratings = (ratings
           # Retain only the fields we need.
           .map(lambda x: {"user_id": x["user_id"], "movie_title": x["movie_title"]})
           # Cache for efficiency.
           .cache(tempfile.NamedTemporaryFile().name)
)
# Get the movies data.
movies = tfds.load("movielens/100k-movies", split="train")
movies = (movies
          # Retain only the fields we need.
          .map(lambda x: x["movie_title"])
          # Cache for efficiency.
          .cache(tempfile.NamedTemporaryFile().name))
```


Before we can build a model, we need to set up the user and movie vocabularies:

```
user_ids = ratings.map(lambda x: x["user_id"])
unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(user_ids.batch(1000))))
```


We'll also set up the training and test sets:

```
tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)
train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)
```

### Model definition

Just as in the basic retrieval tutorial, we build a simple two-tower model.

```
class MovielensModel(tfrs.Model):

    def __init__(self):
        super().__init__()
        embedding_dimension = 32
        # Set up a model for representing movies.
        self.movie_model = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=unique_movie_titles, mask_token=None),
            # We add an additional embedding to account for unknown tokens.
            tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
        ])
        # Set up a model for representing users.
        self.user_model = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=unique_user_ids, mask_token=None),
            # We add an additional embedding to account for unknown tokens.
            tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
        ])
        # Set up a task to optimize the model and compute metrics.
        self.task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=movies.batch(128).cache().map(self.movie_model)
            )
        )

    def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        # We pick out the user features and pass them into the user model.
        user_embeddings = self.user_model(features["user_id"])
        # And pick out the movie features and pass them into the movie model,
        # getting embeddings back.
        positive_movie_embeddings = self.movie_model(features["movie_title"])
        # The task computes the loss and the metrics.
        return self.task(user_embeddings, positive_movie_embeddings, compute_metrics=not training)
```
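To make the two-tower idea concrete, here is a minimal pure-Python sketch of what the two towers compute: each tower maps a string to a row of an embedding table (with row 0 reserved for unknown strings, which is why the tables have one extra row), and the affinity between a user and a movie is the dot product of the two rows. The helper names (`build_tower`, `embed`, `affinity`) and the tiny vocabularies are hypothetical and only for illustration; the real model uses the Keras layers shown above and learns the table values during training.

```python
import random

random.seed(0)
EMBEDDING_DIM = 4  # the tutorial uses 32; kept small here for readability


def build_tower(vocabulary, dim=EMBEDDING_DIM):
    """Return (lookup, table): a string -> row index map and an embedding table.

    Row 0 is reserved for out-of-vocabulary strings, so the table has
    len(vocabulary) + 1 rows.
    """
    lookup = {token: i + 1 for i, token in enumerate(vocabulary)}
    table = [[random.gauss(0.0, 0.05) for _ in range(dim)]
             for _ in range(len(vocabulary) + 1)]
    return lookup, table


def embed(tower, token):
    """Look up the embedding row for a token; unknown tokens use row 0."""
    lookup, table = tower
    return table[lookup.get(token, 0)]


def affinity(user_vec, movie_vec):
    """Two-tower score: the dot product of the two embeddings."""
    return sum(u * m for u, m in zip(user_vec, movie_vec))


# Hypothetical toy vocabularies (the real ones come from the dataset).
user_tower = build_tower(["1", "2", "42"])
movie_tower = build_tower(["Rudy (1993)", "Toy Story (1995)"])

score = affinity(embed(user_tower, "42"), embed(movie_tower, "Rudy (1993)"))
print(f"Untrained affinity score: {score:.4f}")
```

Because the score is a plain dot product, retrieving the best movies for a user reduces to a nearest-neighbour search over the movie embeddings, which is what the rest of this tutorial speeds up.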

### Fitting and evaluation

A TFRS model is just a Keras model. We can compile it:

```
model = MovielensModel()
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
```

Estimate it:

```
model.fit(train.batch(8192), epochs=3)
```

Epoch 1/3 10/10 [==============================] - 3s 223ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 69808.9716 - regularization_loss: 0.0000e+00 - total_loss: 69808.9716

Epoch 2/3 10/10 [==============================] - 3s 222ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 67485.8842 - regularization_loss: 0.0000e+00 - total_loss: 67485.8842

Epoch 3/3 10/10 [==============================] - 3s 220ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 66311.9581 - regularization_loss: 0.0000e+00 - total_loss: 66311.9581

<keras.callbacks.History at 0x7fc02423c150>

And evaluate it.

```
model.evaluate(test.batch(8192), return_dict=True)
```

3/3 [==============================] - 2s 246ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0011 - factorized_top_k/top_5_categorical_accuracy: 0.0095 - factorized_top_k/top_10_categorical_accuracy: 0.0222 - factorized_top_k/top_50_categorical_accuracy: 0.1261 - factorized_top_k/top_100_categorical_accuracy: 0.2363 - loss: 49466.8789 - regularization_loss: 0.0000e+00 - total_loss: 49466.8789

{'factorized_top_k/top_1_categorical_accuracy': 0.0010999999940395355, 'factorized_top_k/top_5_categorical_accuracy': 0.009549999609589577, 'factorized_top_k/top_10_categorical_accuracy': 0.022199999541044235, 'factorized_top_k/top_50_categorical_accuracy': 0.1261499971151352, 'factorized_top_k/top_100_categorical_accuracy': 0.23634999990463257, 'loss': 28242.8359375, 'regularization_loss': 0, 'total_loss': 28242.8359375}
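The `top_k_categorical_accuracy` figures reported above measure how often the movie a user actually watched appears among the top k candidates retrieved for that user. The following is a small illustrative sketch of that idea on made-up data; the function name `top_k_accuracy` and the toy lists are hypothetical, and this is not the TFRS implementation.

```python
def top_k_accuracy(true_items, ranked_lists, k):
    """Fraction of queries whose true item is among the first k retrieved items."""
    hits = sum(1 for truth, ranked in zip(true_items, ranked_lists)
               if truth in ranked[:k])
    return hits / len(true_items)


# Hypothetical data: three queries, each with the item the user actually
# chose and a ranked list of retrieved candidates.
true_items = ["A", "B", "C"]
ranked_lists = [
    ["A", "D", "E"],  # hit at rank 1
    ["D", "E", "B"],  # hit at rank 3
    ["D", "E", "F"],  # miss
]

top_1 = top_k_accuracy(true_items, ranked_lists, k=1)
top_3 = top_k_accuracy(true_items, ranked_lists, k=3)
print(f"top-1 accuracy: {top_1:.2f}, top-3 accuracy: {top_3:.2f}")
```

Larger values of k can only increase this metric, which matches the pattern in the evaluation output (top-100 accuracy is the highest).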

## Approximate prediction

The most straightforward way of retrieving top candidates in response to a query is to do it via brute force: compute user-movie scores for all possible movies, sort them, and pick a couple of top recommendations.
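As a rough sketch of what brute-force retrieval involves, the snippet below scores every candidate against a query with a dot product and keeps the k best. It uses random vectors and plain Python rather than the trained model, and the names (`candidates`, `brute_force_top_k`) are hypothetical.

```python
import heapq
import random

random.seed(0)
DIM = 8

# Hypothetical candidate embeddings standing in for the movie tower's output.
candidates = {f"movie_{i}": [random.gauss(0, 1) for _ in range(DIM)]
              for i in range(1000)}
query = [random.gauss(0, 1) for _ in range(DIM)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def brute_force_top_k(query, candidates, k):
    """Score every candidate against the query and keep the k best."""
    scored = ((dot(query, vec), name) for name, vec in candidates.items())
    return heapq.nlargest(k, scored)  # sorted from best to worst


top = brute_force_top_k(query, candidates, k=3)
for score, name in top:
    print(f"{name}: {score:.3f}")
```

Every query touches every candidate, so the cost grows in proportion to the size of the candidate set.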

In TFRS, this is accomplished via the `BruteForce` layer:

```
brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
brute_force.index_from_dataset(
movies.batch(128).map(lambda title: (title, model.movie_model(title)))
)
```

<tensorflow_recommenders.layers.factorized_top_k.BruteForce at 0x7fbfc1d4fe10>

Once created and populated with candidates (via the `index_from_dataset` method), we can call it to get predictions out:

```
# Get predictions for user 42.
_, titles = brute_force(np.array(["42"]), k=3)
print(f"Top recommendations: {titles[0]}")
```

Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)' b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

On a small dataset of fewer than 2,000 movies, this is very fast:

```
%timeit _, titles = brute_force(np.array(["42"]), k=3)
```

983 µs ± 5.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

But what happens if we have more candidates - millions instead of thousands?

We can simulate this by indexing all of our movies multiple times:

```
# Construct a dataset of movies that's 1,000 times larger. We
# do this by adding more than a million dummy movie titles to the dataset.
lots_of_movies = tf.data.Dataset.concatenate(
    movies.batch(4096),
    movies.batch(4096).repeat(1_000).map(lambda x: tf.zeros_like(x))
)
# We also add lots of dummy embeddings by randomly perturbing
# the estimated embeddings for real movies.
lots_of_movies_embeddings = tf.data.Dataset.concatenate(
    movies.batch(4096).map(model.movie_model),
    movies.batch(4096).repeat(1_000)
      .map(lambda x: model.movie_model(x))
      .map(lambda x: x * tf.random.uniform(tf.shape(x)))
)
```

We can build a `BruteForce` index on this larger dataset:

```
brute_force_lots = tfrs.layers.factorized_top_k.BruteForce()
brute_force_lots.index_from_dataset(
tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))
)
```

<tensorflow_recommenders.layers.factorized_top_k.BruteForce at 0x7fbfc1d80610>

The recommendations are still the same:

```
_, titles = brute_force_lots(model.user_model(np.array(["42"])), k=3)
print(f"Top recommendations: {titles[0]}")
```

Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)' b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

But they take much longer. With a candidate set of 1 million movies, brute force prediction becomes quite slow:

```
%timeit _, titles = brute_force_lots(model.user_model(np.array(["42"])), k=3)
```

33 ms ± 245 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

As the number of candidates grows, the amount of time needed grows linearly: with 10 million candidates, serving top candidates would take hundreds of milliseconds. This is clearly too slow for a live service.

This is where approximate mechanisms come in.

Using ScaNN in TFRS is accomplished via the `tfrs.layers.factorized_top_k.ScaNN` layer. It follows the same interface as the other top k layers:

```
scann = tfrs.layers.factorized_top_k.ScaNN(num_reordering_candidates=100)
scann.index_from_dataset(
tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))
)
```

<tensorflow_recommenders.layers.factorized_top_k.ScaNN at 0x7fbfc2571990>

The recommendations are (approximately!) the same:

```
_, titles = scann(model.user_model(np.array(["42"])), k=3)
print(f"Top recommendations: {titles[0]}")
```

Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)' b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

But they are much, much faster to compute:

```
%timeit _, titles = scann(model.user_model(np.array(["42"])), k=3)
```

4.35 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In this case, we can retrieve the top 3 movies out of a set of ~1 million in around 4 milliseconds: roughly 8 times faster than by computing the best candidates via brute force. The advantage of approximate methods grows even larger for larger datasets.
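The speed-up can be read directly off the two `%timeit` outputs shown above. The short calculation below simply divides one by the other; the numbers are specific to the run recorded in this tutorial and will differ on other hardware.

```python
# Timings copied from the %timeit outputs shown above, in milliseconds.
brute_force_ms = 33.0
scann_ms = 4.35

speedup = brute_force_ms / scann_ms
print(f"ScaNN is about {speedup:.1f}x faster than brute force in this run")
```

Both timings include the call to `model.user_model`, so the ratio understates the speed-up of the search step itself.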

## Evaluating the approximation

When using approximate top K retrieval mechanisms (such as ScaNN), speed of retrieval often comes at the expense of accuracy. To understand this trade-off, it's important to measure the model's evaluation metrics when using ScaNN, and to compare them with the baseline.

Fortunately, TFRS makes this easy. We simply override the metrics on the retrieval task with metrics using ScaNN, re-compile the model, and run evaluation.

To make the comparison, let's first run baseline results. We still need to override our metrics to make sure they are using the enlarged candidate set rather than the original set of movies:

```
# Override the existing streaming candidate source.
model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
candidates=lots_of_movies_embeddings
)
# Need to recompile the model for the changes to take effect.
model.compile()
%time baseline_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)
```

CPU times: user 22min 5s, sys: 2min 7s, total: 24min 12s Wall time: 51.9 s

We can do the same using ScaNN:

```
model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
candidates=scann
)
model.compile()
# ScaNN evaluation is more memory efficient, so a much bigger batch size
# could be used here if needed.
%time scann_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)
```

CPU times: user 10.5 s, sys: 3.26 s, total: 13.7 s Wall time: 1.85 s

ScaNN based evaluation is much, much quicker: it's over ten times faster! This advantage is going to grow even larger for bigger datasets, and so for large datasets it may be prudent to always run ScaNN-based evaluation to improve model development velocity.

But how about the results? We can compare the top-100 accuracy of the two approaches:

```
print(f"Brute force top-100 accuracy: {baseline_result['factorized_top_k/top_100_categorical_accuracy']:.2f}")
print(f"ScaNN top-100 accuracy: {scann_result['factorized_top_k/top_100_categorical_accuracy']:.2f}")
```

Brute force top-100 accuracy: 0.15

ScaNN top-100 accuracy: 0.27

In this run the two figures differ, which shows that on this artificial dataset the approximation does change which candidates are retrieved. In general, all approximate methods exhibit speed-accuracy tradeoffs. To understand this in more depth you can check out Erik Bernhardsson's ANN benchmarks.

## Deploying the approximate model

The `ScaNN`-based model is fully integrated into TensorFlow models, and serving it is as easy as serving any other TensorFlow model.

We can save it as a `SavedModel` object:

```
lots_of_movies_embeddings
```

<ConcatenateDataset shapes: (None, 32), types: tf.float32>

```
# We re-index the ScaNN layer to include the user embeddings in the same model.
# This way we can give the saved model raw features and get valid predictions
# back.
scann = tfrs.layers.factorized_top_k.ScaNN(model.user_model, num_reordering_candidates=1000)
scann.index_from_dataset(
    tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))
)
# Need to call it to set the shapes.
_ = scann(np.array(["42"]))
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model")
    tf.saved_model.save(
        scann,
        path,
        options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
    )
    loaded = tf.saved_model.load(path)
```


and then load it and serve, getting exactly the same results back:

```
_, titles = loaded(tf.constant(["42"]))
print(f"Top recommendations: {titles[0][:3]}")
```

The resulting model can be served in any Python service that has TensorFlow and ScaNN installed.

It can also be served using a customized version of TensorFlow Serving, available as a Docker container on Docker Hub. You can also build the image yourself from the Dockerfile.

## Tuning ScaNN

Now let's look into tuning our ScaNN layer to get a better performance/accuracy tradeoff. In order to do this effectively, we first need to measure our baseline performance and accuracy.

From above, we already have a measurement of our model's latency for processing a single (non-batched) query (although note that a fair amount of this latency is from non-ScaNN components of the model).

Now we need to investigate ScaNN's accuracy, which we measure through recall. A recall@k of x% means that if we use brute force to retrieve the true top k neighbors, and compare those results to using ScaNN to also retrieve the top k neighbors, x% of ScaNN's results are in the true brute force results. Let's compute the recall for the current ScaNN searcher.

First, we need to generate the brute force, ground truth top-k:

```
# Process queries in groups of 1000; processing them all at once with brute force
# may lead to out-of-memory errors, because processing a batch of q queries against
# a size-n dataset takes O(nq) space with brute force.
titles_ground_truth = tf.concat([
brute_force_lots(queries, k=10)[1] for queries in
test.batch(1000).map(lambda x: model.user_model(x["user_id"]))
], axis=0)
```

Our variable `titles_ground_truth` now contains the top-10 movie recommendations returned by brute-force retrieval. Now we can compute the same recommendations when using ScaNN:

```
# Get all user_id's as a 1d tensor of strings
test_flat = np.concatenate(list(test.map(lambda x: x["user_id"]).batch(1000).as_numpy_iterator()), axis=0)
# ScaNN is much more memory efficient and has no problem processing the whole
# batch of 20000 queries at once.
_, titles = scann(test_flat, k=10)
```

Next, we define our function that computes recall. For each query, it counts how many results are in the intersection of the brute force and the ScaNN results and divides this by the number of brute force results. The average of this quantity over all queries is our recall.

```
def compute_recall(ground_truth, approx_results):
    return np.mean([
        len(np.intersect1d(truth, approx)) / len(truth)
        for truth, approx in zip(ground_truth, approx_results)
    ])
```
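To see how this recall computation behaves, here is a small worked example on made-up results. It uses a pure-Python equivalent of the function above (the name `recall_at_k` and the toy lists are hypothetical) so that it runs without NumPy.

```python
def recall_at_k(ground_truth, approx_results):
    """Mean fraction of each query's true neighbours found by the approximate search."""
    per_query = [
        len(set(truth) & set(approx)) / len(truth)
        for truth, approx in zip(ground_truth, approx_results)
    ]
    return sum(per_query) / len(per_query)


# Hypothetical top-3 results for two queries.
ground_truth = [["A", "B", "C"], ["D", "E", "F"]]
approx = [
    ["A", "C", "B"],  # same items in a different order: 3/3
    ["D", "X", "F"],  # one true neighbour missed: 2/3
]

recall = recall_at_k(ground_truth, approx)
print(f"Recall: {recall:.3f}")
```

Note that the order of results within the top k does not matter here; only set membership counts.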

This gives us baseline recall@10 with the current ScaNN config:

```
print(f"Recall: {compute_recall(titles_ground_truth, titles):.3f}")
```

Recall: 0.931

We can also measure the baseline latency:

```
%timeit -n 1000 scann(np.array(["42"]), k=10)
```

4.67 ms ± 25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Let's see if we can do better!

To do this, we need a model of how ScaNN's tuning knobs affect performance. Our current model uses ScaNN's tree-AH algorithm. This algorithm partitions the database of embeddings (the "tree") and then scores the most promising of these partitions using AH, which is a highly optimized approximate distance computation routine.

The default parameters for TensorFlow Recommenders' ScaNN Keras layer set `num_leaves=100` and `num_leaves_to_search=10`. This means our database is partitioned into 100 disjoint subsets, and the 10 most promising of these partitions are scored with AH. In other words, 10/100=10% of the dataset is being searched with AH.
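The partition-then-search idea can be sketched in a few lines of plain Python. This is a deliberately simplified illustration and not ScaNN's implementation: the centroids are just randomly sampled points rather than trained ones, the shortlisted points are scored with exact dot products instead of AH, and all vectors are normalised to unit length for simplicity. The constants `NUM_LEAVES` and `NUM_LEAVES_TO_SEARCH` mirror the layer's parameters but use much smaller, hypothetical values.

```python
import heapq
import math
import random

random.seed(0)
DIM, NUM_POINTS = 8, 2000
NUM_LEAVES, NUM_LEAVES_TO_SEARCH = 20, 4  # searches 4 of the 20 leaves


def unit_vector(dim):
    v = [random.gauss(0, 1) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


points = [unit_vector(DIM) for _ in range(NUM_POINTS)]

# "Tree" step: pick centroids (here simply sampled points, not trained) and
# assign every point to the centroid it scores highest against.
centroids = random.sample(points, NUM_LEAVES)
leaves = [[] for _ in range(NUM_LEAVES)]
for idx, p in enumerate(points):
    best_leaf = max(range(NUM_LEAVES), key=lambda c: dot(p, centroids[c]))
    leaves[best_leaf].append(idx)


def partitioned_search(query, k):
    """Score only the points that live in the most promising leaves."""
    best_leaves = heapq.nlargest(
        NUM_LEAVES_TO_SEARCH, range(NUM_LEAVES),
        key=lambda c: dot(query, centroids[c]))
    shortlist = [idx for leaf in best_leaves for idx in leaves[leaf]]
    # ScaNN would score the shortlist with AH; exact dot products are used here.
    top = heapq.nlargest(k, shortlist, key=lambda i: dot(query, points[i]))
    return top, len(shortlist)


query = unit_vector(DIM)
found, num_scored = partitioned_search(query, k=10)
exact = heapq.nlargest(10, range(NUM_POINTS), key=lambda i: dot(query, points[i]))
recall = len(set(found) & set(exact)) / 10
print(f"scored {num_scored} of {NUM_POINTS} points, recall@10 = {recall:.2f}")
```

Only the points in the selected leaves are scored, which is where the speed-up comes from; any true neighbour that sits in an unsearched leaf is missed, which is where the recall loss comes from.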

If we have, say, `num_leaves=1000` and `num_leaves_to_search=100`, we would also be searching 10% of the database with AH. However, in comparison to the previous setting, the 10% we would search will contain higher-quality candidates, because a higher `num_leaves` allows us to make finer-grained decisions about what parts of the dataset are worth searching.

It's no surprise then that with `num_leaves=1000` and `num_leaves_to_search=100` we get significantly higher recall:

```
scann2 = tfrs.layers.factorized_top_k.ScaNN(
model.user_model,
num_leaves=1000,
num_leaves_to_search=100,
num_reordering_candidates=1000)
scann2.index_from_dataset(
tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))
)
_, titles2 = scann2(test_flat, k=10)
print(f"Recall: {compute_recall(titles_ground_truth, titles2):.3f}")
```

Recall: 0.966

However, as a tradeoff, our latency has also increased. This is because the partitioning step has gotten more expensive; `scann` picks the top 10 of 100 partitions while `scann2` picks the top 100 of 1000 partitions. The latter can be more expensive because it involves looking at 10 times as many partitions.

```
%timeit -n 1000 scann2(np.array(["42"]), k=10)
```

4.86 ms ± 21.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In general, tuning ScaNN search is about picking the right tradeoffs. Each individual parameter change generally won't make search both faster and more accurate; our goal is to tune the parameters to optimally trade off between these two conflicting goals.

In our case, `scann2` significantly improved recall over `scann` at some cost in latency. Can we dial back some other knobs to cut down on latency, while preserving most of our recall advantage?

Let's try searching 70/1000=7% of the dataset with AH, and only rescoring the final 400 candidates:

```
scann3 = tfrs.layers.factorized_top_k.ScaNN(
model.user_model,
num_leaves=1000,
num_leaves_to_search=70,
num_reordering_candidates=400)
scann3.index_from_dataset(
tf.data.Dataset.zip((lots_of_movies, lots_of_movies_embeddings))
)
_, titles3 = scann3(test_flat, k=10)
print(f"Recall: {compute_recall(titles_ground_truth, titles3):.3f}")
```

Recall: 0.957

`scann3` delivers about a 3% absolute recall gain over `scann` while also delivering lower latency:

```
%timeit -n 1000 scann3(np.array(["42"]), k=10)
```

4.58 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
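The role of `num_reordering_candidates` can be illustrated with a toy two-stage search: cheap approximate scores select a shortlist, and only that shortlist is rescored exactly. The sketch below is an assumption-laden stand-in, not ScaNN's algorithm: the "compressed" index is produced by crudely rounding each coordinate, and all names and sizes are hypothetical.

```python
import heapq
import random

random.seed(0)
DIM, NUM_POINTS = 8, 2000
points = [[random.gauss(0, 1) for _ in range(DIM)] for _ in range(NUM_POINTS)]
# Crude stand-in for a compressed index: round every coordinate to the nearest 0.5.
compressed = [[round(x * 2) / 2 for x in p] for p in points]
query = [random.gauss(0, 1) for _ in range(DIM)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def search(k, num_reordering_candidates):
    # Stage 1: approximate scores on the compressed vectors pick a shortlist.
    shortlist = heapq.nlargest(
        num_reordering_candidates, range(NUM_POINTS),
        key=lambda i: dot(query, compressed[i]))
    # Stage 2: exact scores reorder the shortlist; keep the best k.
    return heapq.nlargest(k, shortlist, key=lambda i: dot(query, points[i]))


exact = set(heapq.nlargest(10, range(NUM_POINTS), key=lambda i: dot(query, points[i])))
recalls = {n: len(exact & set(search(10, n))) / 10 for n in (10, 50, 400)}
for n, recall in recalls.items():
    print(f"num_reordering_candidates={n}: recall@10 = {recall:.1f}")
```

A larger shortlist can only help recall in this setup, but every extra candidate costs one more exact score computation, which is the latency side of the tradeoff.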

These knobs can be further adjusted to optimize for different points along the accuracy-performance Pareto frontier. ScaNN's algorithms can achieve state-of-the-art performance over a wide range of recall targets.
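As a small illustration of what being on the Pareto frontier means here, the snippet below takes the recall and latency figures reported by the three runs in this section and keeps only the configurations that are not beaten on both axes by another one. The helper name `dominates` is hypothetical, and the numbers are specific to the run recorded in this tutorial.

```python
# (recall@10, latency in ms) as reported by the runs above.
configs = {
    "scann": (0.931, 4.67),
    "scann2": (0.966, 4.86),
    "scann3": (0.957, 4.58),
}


def dominates(a, b):
    """True if a is at least as good as b on both axes and strictly better on one."""
    (recall_a, latency_a), (recall_b, latency_b) = a, b
    return (recall_a >= recall_b and latency_a <= latency_b
            and (recall_a > recall_b or latency_a < latency_b))


frontier = sorted(
    name for name, point in configs.items()
    if not any(dominates(other, point) for other in configs.values())
)
print(frontier)
```

With these figures, `scann` drops out because `scann3` has both higher recall and lower latency, while `scann2` and `scann3` represent different tradeoffs between the two.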

## Further reading

ScaNN uses advanced vector quantization techniques and a highly optimized implementation to achieve its results. The field of vector quantization has a rich history with a variety of approaches. ScaNN's current quantization technique is detailed in this paper, published at ICML 2020. The paper was also released along with this blog article, which gives a high-level overview of our technique.

Many related quantization techniques are mentioned in the references of our ICML 2020 paper, and other ScaNN-related research is listed at http://sanjivk.com/