Loading Data for Model Training

LanceDB makes an excellent data backend for training machine learning models. This section describes how training frameworks load data and how to use LanceDB to select, order, and project the data a model is trained on.

Data Loading

Most model training frameworks iterate through data in batches and feed each batch into the model. This process is often referred to as “data loading”. The simplest way to load data is to iterate over a LanceDB table in a loop and feed each batch into the model.

```python
import lancedb

db = lancedb.connect("file://some/db/path")
table = db.open_table("some_table")

for batch in table:
    print(batch.to_pydict())
```
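What "iterating in batches" means can be sketched in plain Python without LanceDB. In this sketch the table is a hypothetical in-memory dict of column lists, and `iter_batches` is an illustrative helper, not part of the LanceDB API. Each batch holds the same slice of rows from every column, and the last batch may be shorter than the others.

```python
from typing import Dict, Iterator, List

# Hypothetical in-memory stand-in for a table: column name -> list of values.
Columns = Dict[str, List[object]]


def iter_batches(columns: Columns, batch_size: int) -> Iterator[Columns]:
    """Yield consecutive, non-overlapping batches of at most batch_size rows."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Assumes every column has the same length; an empty table yields nothing.
    num_rows = len(next(iter(columns.values()), []))
    for start in range(0, num_rows, batch_size):
        # Slicing past the end is safe, so the final batch may be short.
        yield {name: values[start:start + batch_size] for name, values in columns.items()}


data = {"id": list(range(5)), "prompt": [f"p{i}" for i in range(5)]}
for batch in iter_batches(data, batch_size=2):
    print(batch)  # stands in for "feed this batch into the model"
```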

In practice, this is too simplistic for effective training. We may not want to load all of the data, we may want to load it in a different order, or we may need to process it before training. To handle these cases, we usually create a Permutation of the table.

Selecting Rows

When training a model, we might not want to load all of the data. For example, we might filter out rows that are not relevant to training. We might also divide the data into training and validation sets, or into multiple sets for cross-validation.

Whenever we create a permutation of the table, we first have to decide which rows to include (and in what order). This is stored in a “permutation table”, which records the row ids that make up our data. Other decisions, such as which columns to include and what transformations to apply, can be made at read time and don’t require a separate permutation table.
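The split between the stored row selection and the read-time choices can be modeled in a few lines of plain Python. This is an illustration, not LanceDB's implementation. The table is a hypothetical dict of column lists, and list positions stand in for LanceDB row ids. The permutation table is just an ordered list of those ids, and the column selection is passed in separately when reading.

```python
from typing import Dict, List, Sequence

Columns = Dict[str, List[object]]


def read_permutation(
    columns: Columns, row_ids: Sequence[int], select: Sequence[str]
) -> Columns:
    """Gather the selected columns for the given row ids, in permutation order.

    `row_ids` plays the role of the stored permutation table; `select` is a
    read-time choice that needs no separate table.
    """
    return {name: [columns[name][i] for i in row_ids] for name in select}


toy_table = {
    "id": [0, 1, 2, 3],
    "prompt": ["a", "b", "c", "d"],
    "category": ["cat", "dog", "cat", "bird"],
}
# Hypothetical permutation table: rows 2 and 0, in that order.
perm_row_ids = [2, 0]
print(read_permutation(toy_table, perm_row_ids, select=["id", "prompt"]))
# {'id': [2, 0], 'prompt': ['c', 'a']}
```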

💡 Permutation tables are tables, just like any other table in LanceDB. By default they are stored in memory, but they can also be persisted to storage. This is useful when you want to share a permutation table across processes or nodes.

Selecting All Rows

To select all rows, we can use the Permutation.identity method. This gives us a Permutation without creating a separate permutation table. We can still select columns and apply transformations, which is useful when the data loader itself is responsible for sampling and shuffling.

```python
from lancedb.permutation import Permutation

# We can create an identity permutation without needing any separate permutation table.
permutation = Permutation.identity(table)

# This allows us to refine our columns and apply transformations
permutation = permutation.select_columns(["id", "prompt"])
```
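With the identity permutation, shuffling is left to the data loader. The sketch below shows that pattern in plain Python. It assumes the loader only needs row positions 0..n-1 and reshuffles them every epoch with a seeded generator, similar in spirit to framework samplers. `epoch_order` and `iter_index_batches` are hypothetical helpers, not LanceDB or PyTorch APIs.

```python
import random
from typing import Iterator, List


def epoch_order(num_rows: int, epoch: int, seed: int = 0) -> List[int]:
    """Return a new random order of all row positions for one epoch.

    The identity permutation exposes rows in storage order (0..num_rows-1);
    here the loader, not a permutation table, reshuffles them every epoch.
    """
    order = list(range(num_rows))
    # Deriving the generator seed from the epoch keeps each epoch reproducible.
    random.Random(seed + epoch).shuffle(order)
    return order


def iter_index_batches(order: List[int], batch_size: int) -> Iterator[List[int]]:
    """Split an ordering of row positions into batches of at most batch_size."""
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


for epoch in range(2):
    batches = list(iter_index_batches(epoch_order(6, epoch), batch_size=4))
    print(f"epoch {epoch}: {batches}")
```

Because the order is recomputed each epoch, the model never sees the same sequence twice in a row. A shuffled permutation table, by contrast, fixes one order up front.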

Filtering Rows

If we only want to select a subset of rows, then we can use a filter. This will require us to create a permutation table which identifies which rows we want to include.

```python
from lancedb.permutation import Permutation, permutation_builder

# We can create a permutation table which identifies which rows we want to include.
permutation_tbl = permutation_builder(table).filter("category = 'cat'").execute()

# We can then use this permutation table to create a Permutation object
permutation = Permutation.from_tables(table, permutation_tbl)
```
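Conceptually, building a filtered permutation table means scanning the rows once and keeping the ids of the rows that match. The sketch below models that in plain Python. The rows are illustrative, the lambda stands in for the SQL filter `category = 'cat'`, and list positions stand in for LanceDB row ids. It does not reflect how LanceDB evaluates filters internally.

```python
from typing import Callable, Dict, List

Row = Dict[str, object]


def filter_row_ids(rows: List[Row], predicate: Callable[[Row], bool]) -> List[int]:
    """Return the positions of rows that satisfy the predicate, in storage order."""
    return [i for i, row in enumerate(rows) if predicate(row)]


rows = [
    {"id": 0, "category": "cat"},
    {"id": 1, "category": "dog"},
    {"id": 2, "category": "cat"},
]
# Plays the role of the SQL filter "category = 'cat'".
print(filter_row_ids(rows, lambda r: r["category"] == "cat"))  # [0, 2]
```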

Creating Splits

LanceDB also provides several different methods for creating splits. These allow us to divide our dataset into smaller non-overlapping sets. The split can then be specified when creating the Permutation object to view only a subset of the data.

```python
from lancedb.permutation import Permutation, permutation_builder

# Here we create two splits, one for training and one for validation.  By default, splits have no
# name and are accessed by index.
permutation_tbl = permutation_builder(table).split_random(ratios=[0.95, 0.05]).execute()

# Let's create a permutation object which views only the training data.
permutation = Permutation.from_tables(table, permutation_tbl, split=0)

# Splits can also be given names.  The names can then be used later to access the split instead of
# requiring us to know the index.
permutation_tbl = (
    permutation_builder(table)
    .split_random(ratios=[0.95, 0.05], split_names=["train", "test"])
    .execute()
)
permutation = Permutation.from_tables(table, permutation_tbl, split="train")
```
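A random split can be pictured as shuffling the row positions once and then cutting the shuffled list into consecutive, non-overlapping pieces. The function below is a hypothetical plain-Python sketch, not LanceDB's `split_random`. Its rounding rule (floor each split size and give leftover rows to the last split), its requirement that ratios sum to 1, and its fixed default seed are all assumptions of this sketch.

```python
import random
from typing import Dict, List, Optional, Sequence


def split_random(
    num_rows: int,
    ratios: Sequence[float],
    names: Optional[Sequence[str]] = None,
    seed: int = 42,
) -> Dict[object, List[int]]:
    """Shuffle row positions once, then cut them into non-overlapping splits.

    Split sizes are floor(ratio * num_rows); rows left over from rounding go
    to the last split, so every row lands in exactly one split.
    """
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    # Unnamed splits are keyed by index, named splits by name.
    keys = list(names) if names is not None else list(range(len(ratios)))
    if len(keys) != len(ratios):
        raise ValueError("need one name per ratio")
    order = list(range(num_rows))
    random.Random(seed).shuffle(order)
    splits: Dict[object, List[int]] = {}
    start = 0
    for i, (key, ratio) in enumerate(zip(keys, ratios)):
        stop = num_rows if i == len(ratios) - 1 else start + int(ratio * num_rows)
        splits[key] = order[start:stop]
        start = stop
    return splits


splits = split_random(100, [0.95, 0.05], names=["train", "test"])
print(len(splits["train"]), len(splits["test"]))  # 95 5
```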

Shuffling Rows

By default, permutations access the data in the order it is stored in the table. Training on data in a fixed order can cause our model to learn artifacts of that order, which is one of many ways we can “overfit” a model to its data. To avoid this, we typically shuffle the data before training. Model training frameworks (like PyTorch) often provide a way to shuffle the data. If you are not using one of these frameworks, or if you prefer to have LanceDB do the shuffling, you can shuffle the rows when you create a permutation table.

```python
from lancedb.permutation import Permutation, permutation_builder

# We can shuffle the rows when we create the permutation table.
permutation_tbl = permutation_builder(table).shuffle().execute()

# We can then use this permutation table to create a Permutation object; it will now
# access the data in a random order.
permutation = Permutation.from_tables(table, permutation_tbl)
```
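Shuffling when the permutation table is created fixes one random order and stores it. The sketch below uses a classic Fisher-Yates shuffle with a seeded generator to show that idea. It is an illustration, not how LanceDB shuffles internally, and `shuffled_row_ids` and the example ids are hypothetical. A fixed seed makes the stored order reproducible, which helps when the table is persisted and shared.

```python
import random
from typing import List, Sequence


def shuffled_row_ids(row_ids: Sequence[int], seed: int) -> List[int]:
    """Return a Fisher-Yates shuffle of a copy of row_ids.

    A fixed seed makes the stored order reproducible, which matters when a
    permutation table is persisted and shared across processes or nodes.
    """
    ids = list(row_ids)
    rng = random.Random(seed)
    # Walk from the end, swapping each slot with a uniformly chosen slot at or before it.
    for i in range(len(ids) - 1, 0, -1):
        j = rng.randint(0, i)  # randint is inclusive on both ends
        ids[i], ids[j] = ids[j], ids[i]
    return ids


# Hypothetical row ids, e.g. the survivors of an earlier filter.
filtered_ids = [0, 2, 5, 7, 11]
print(shuffled_row_ids(filtered_ids, seed=1234))
```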

Selecting Columns

By default, permutations will return all columns in the table. If you only need a subset of the columns, you can significantly reduce your I/O requirements by selecting only the columns you need. This can be done on the permutation object itself, and does not require us to create a separate permutation table.

```python
from lancedb.permutation import Permutation

# We can select only the columns we need.
permutation = Permutation.identity(table).select_columns(["id", "prompt"])
```
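A rough estimate shows why projection matters in a columnar format such as Lance. Each column is stored separately, so a reader can skip the columns it does not need. The sizes below are hypothetical: one million rows with an int64 id, prompts assumed to average about 200 bytes, and 768-dimensional float32 embeddings. The model ignores compression and metadata.

```python
from typing import Dict, Iterable


def bytes_to_read(column_bytes: Dict[str, int], selected: Iterable[str]) -> int:
    """Approximate bytes read when only the selected columns are loaded."""
    return sum(column_bytes[name] for name in selected)


num_rows = 1_000_000
sizes = {
    "id": num_rows * 8,            # int64
    "prompt": num_rows * 200,      # assumed ~200 bytes of text per row
    "vector": num_rows * 768 * 4,  # 768-dim float32 embedding
}
print(bytes_to_read(sizes, sizes))             # all columns: 3280000000
print(bytes_to_read(sizes, ["id", "prompt"]))  # projection: 208000000
```

Under these assumptions, dropping the embedding column cuts the bytes read by roughly 16x. Real savings depend on the actual schema and encoding.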