Intermediate

Processing Datasets

Transform datasets with map(), filter(), sort(), shuffle(), and related operations. map() and filter() can process examples in batches, run across multiple processes with num_proc, and cache their results automatically.

The map() Function

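The map() function is the workhorse of dataset processing. It applies a function to every example (or to every batch of examples) and adds the columns the function returns to the dataset, overwriting any existing columns with the same names: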

Python
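from datasets import load_dataset
from transformers import AutoTokenizer

# Example dataset with "text" and "label" columns (IMDB used for illustration)
dataset = load_dataset("imdb", split="train")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Per-example: the function receives one row as a dict of values
def tokenize(example):
    return tokenizer(example["text"], truncation=True, padding="max_length")

tokenized = dataset.map(tokenize)

# Batched: the function receives a dict of lists (up to batch_size rows per call),
# so the tokenizer handles many texts at once (much faster)
def tokenize_batch(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length")

tokenized = dataset.map(tokenize_batch, batched=True, batch_size=1000)

# Parallel processing: split the work across 4 worker processes
tokenized = dataset.map(tokenize_batch, batched=True, num_proc=4)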

Filtering

Python
# Keep only positive reviews
positive = dataset.filter(lambda x: x["label"] == 1)

# Keep texts longer than 500 characters
long_texts = dataset.filter(lambda x: len(x["text"]) > 500)

# Batched filter: return one bool per example in the batch (faster on large datasets)
filtered = dataset.filter(
    lambda batch: [len(t) > 100 for t in batch["text"]],
    batched=True
)

Common Operations

Python
# Shuffle (pass a seed for reproducibility)
shuffled = dataset.shuffle(seed=42)

# Add a new column (one value per row)
lengths = [len(t) for t in dataset["text"]]
dataset = dataset.add_column("length", lengths)

# Sort by a column (the column must already exist)
sorted_ds = dataset.sort("length")

# Select specific rows by index
subset = dataset.select(range(100))

# Rename columns
renamed = dataset.rename_column("label", "sentiment")

# Remove columns
cleaned = dataset.remove_columns(["length"])

# Concatenate datasets (dataset1 and dataset2 must have the same columns and types)
from datasets import concatenate_datasets
combined = concatenate_datasets([dataset1, dataset2])

# Train/test split: returns a DatasetDict with "train" and "test" splits
split = dataset.train_test_split(test_size=0.2, seed=42)

Format for Training

Python
# Set the output format to PyTorch tensors (only the listed columns are returned)
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Now indexing returns PyTorch tensors
tokenized[0]["input_ids"]  # tensor([101, ...]); 101 is BERT's [CLS] token

# Create a PyTorch DataLoader (every row was padded to max_length,
# so the default collate function can stack them)
from torch.utils.data import DataLoader
loader = DataLoader(tokenized, batch_size=32, shuffle=True)

# Alternatively, convert to a tf.data.Dataset for TensorFlow
tf_dataset = tokenized.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["label"],
    batch_size=32,
    shuffle=True
)
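
Caching: Results of map() and filter() are written to cache files and reused automatically. If you apply the same function with the same arguments to the same dataset again, the result is loaded from the cache instead of being recomputed. Pass load_from_cache_file=False to force recomputation. Caching applies to datasets backed by files on disk (for example, ones loaded with load_dataset()); datasets created in memory are not cached.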

Next: Streaming

Learn how to process datasets larger than memory using streaming mode and iterable datasets.
