ROOT Reference Guide
 
ml_dataloader_resampling.py File Reference

Detailed Description

Example of resampling when one class is underrepresented in the dataset.
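
To make the idea concrete, here is a minimal sketch of oversampling in plain Python, independent of ROOT. The helper `oversample` and its `ratio` argument are hypothetical names used only for illustration; they are not part of the ROOT API, and the data loader in the tutorial may implement resampling differently.

```python
import random


def oversample(major, minor, ratio, seed=42):
    """Hypothetical helper: draw minority entries with replacement until
    len(minor) / len(major) reaches `ratio`."""
    rng = random.Random(seed)
    target = round(len(major) * ratio)
    if target <= len(minor):
        # Already at or above the requested ratio: nothing to add.
        return list(major), list(minor)
    extra = rng.choices(minor, k=target - len(minor))
    return list(major), list(minor) + extra


major = [2 * i for i in range(100000)]    # even numbers (class 0)
minor = [2 * i + 1 for i in range(1000)]  # odd numbers (class 1)

major_s, minor_s = oversample(major, minor, ratio=0.1)
minority_fraction = len(minor_s) / (len(major_s) + len(minor_s))
print(len(minor_s), round(minority_fraction, 4))
```

Undersampling would be the mirror image: instead of repeating minority entries, a random subset of the majority class is kept so that the same ratio is reached with fewer entries per epoch.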

import ROOT
import torch
seed = 42
# Create an imbalanced dataset with two classes, one of which is underrepresented.
# Here, we'll create two files, one with even numbers and one with odd numbers,
# and then merge them to form a dataset with underrepresented odd numbers.
def make_df(b1_expr, num_events):
    return ROOT.RDataFrame(num_events).Define("b1", b1_expr).Define("b2", "(int) b1%2")
df_major = make_df("(int) 2 * rdfentry_", 100000)
df_minor = make_df("(int) 2 * rdfentry_ + 1", 1000)
batch_size = 256
num_epochs = 20
# NOTE: loss_fn is used below but its definition does not appear in this listing.
# BCEWithLogitsLoss is an assumption.
loss_fn = torch.nn.BCEWithLogitsLoss()

# Function to train the model and print useful loss statistics
def train_model(model, optimizer, dataloader):
    train, val = dataloader.train_test_split(test_size=0.2)
    for _ in range(num_epochs):
        train_correct = 0
        train_total = 0
        train_losses = []
        for X, y in train.as_torch():
            # Standard optimization step; without it the model is never updated
            # and train_losses stays empty.
            optimizer.zero_grad()
            outputs = model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            preds = (outputs > 0.5).float()
            train_correct += (preds == y).sum().item()
            train_total += y.size(0)
        print(
            f"Training => Accuracy: {int(train_correct / train_total * 100000) / 100000}; Loss: {int(sum(train_losses) / len(train_losses) * 100000) / 100000}"
        )
    # Validation runs once, after the last epoch
    val_losses = []
    val_correct = 0
    val_total = 0
    for X, y in val.as_torch():
        with torch.no_grad():
            outputs = model(X)
            loss = loss_fn(outputs, y)
            val_losses.append(loss.item())
            preds = (outputs > 0.5).float()
            val_correct += (preds == y).sum().item()
            val_total += y.size(0)
    print(
        f"Validation => Accuracy: {int(val_correct / val_total * 100000) / 100000}; Loss: {int(sum(val_losses) / len(val_losses) * 100000) / 100000}\n"
    )
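# Oversampling strategy: more batches of the underrepresented class
# Takes more time per epoch, but each epoch is more balanced & effective
# NOTE: the line that opens this call does not appear in this listing.
# The constructor name ROOT.Experimental.ML.RDataLoader is an assumption.
dl_oversampled = ROOT.Experimental.ML.RDataLoader(
    [df_major, df_minor],
    batch_size=batch_size,
    target="b2",
    set_seed=seed,
    load_eager=True,  # Must be enabled for resampling
    sampling_type="oversampling",  # Can also be "undersampling"
    sampling_ratio=0.1,  # ~10% of the data will be from the underrepresented class, instead of ~1%
)
oversampling_model = torch.nn.Linear(1, 1)
# NOTE: the optimizer definition does not appear in this listing.
# Adam with default settings is an assumption.
oversampling_optimizer = torch.optim.Adam(oversampling_model.parameters())
print("Training with oversampling:")
train_model(oversampling_model, oversampling_optimizer, dl_oversampled)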
Training with oversampling:
Training => Accuracy: 0.09089; Loss: 54054.4592
Training => Accuracy: 0.09089; Loss: 22879.47853
Training => Accuracy: 0.71861; Loss: 852.54963
Training => Accuracy: 0.9091; Loss: 1.21655
Training => Accuracy: 0.9091; Loss: 1.0924
Training => Accuracy: 0.9091; Loss: 0.938
Training => Accuracy: 0.9091; Loss: 0.74917
Training => Accuracy: 0.9091; Loss: 0.52145
Training => Accuracy: 0.90906; Loss: 0.2629
Training => Accuracy: 0.90798; Loss: 0.09193
Training => Accuracy: 0.90683; Loss: 0.07532
Training => Accuracy: 0.90621; Loss: 0.07395
Training => Accuracy: 0.90539; Loss: 0.07244
Training => Accuracy: 0.90465; Loss: 0.07081
Training => Accuracy: 0.90397; Loss: 0.06906
Training => Accuracy: 0.90324; Loss: 0.06723
Training => Accuracy: 0.9283; Loss: 0.06534
Training => Accuracy: 0.99265; Loss: 0.06342
Training => Accuracy: 0.99201; Loss: 0.06152
Training => Accuracy: 0.99136; Loss: 0.05965
Validation => Accuracy: 0.99016; Loss: 0.05549
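
The plateau values in this log can be checked with a little arithmetic. The sketch below assumes that `sampling_ratio=0.1` means one minority entry for every ten majority entries; this reading is inferred from the printed accuracies, not taken from the ROOT documentation. Under that assumption, a model that predicts class 1 for every entry scores about 1/11, and a model that predicts class 0 for every entry scores about 10/11.

```python
sampling_ratio = 0.1  # assumed meaning: minority count / majority count

minority_fraction = sampling_ratio / (1 + sampling_ratio)
majority_fraction = 1 - minority_fraction

print(f"always predict 1: {minority_fraction:.4f}")
print(f"always predict 0: {majority_fraction:.4f}")
```

These two fractions are consistent with the accuracy of the first two epochs (0.09089) and with the long plateau (0.9091). Only in the last epochs does the accuracy climb to about 0.99.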
Author
Jonah Ascoli

Definition in file ml_dataloader_resampling.py.