Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ml_dataloader_resampling.py
Go to the documentation of this file.
1### \file
2### \ingroup tutorial_ml
3### \notebook -nodraw
4### Example of resampling when one class is underrepresented in the dataset.
5###
6### \macro_code
7### \macro_output
8### \author Jonah Ascoli
9
10import ROOT
11import torch
12
13seed = 42
15
16
17# Create an imbalanced dataset with two classes, one of which is underrepresented.
18# Here, we'll create two dataframes, one with even numbers and one with odd numbers,
19# and then pass both to the data loader to form a dataset with underrepresented odd numbers.
def make_df(b1_expr, num_events):
    """Build an RDataFrame with ``num_events`` entries and columns ``b1`` and ``b2``.

    ``b1`` is defined by the expression string ``b1_expr``; ``b2`` is defined
    by the expression ``(int) b1%2``.
    """
    frame = ROOT.RDataFrame(num_events)
    frame = frame.Define("b1", b1_expr)
    return frame.Define("b2", "(int) b1%2")
22
23
# Majority frame: 100000 even values. Minority frame: 1000 odd values.
df_major = make_df(b1_expr="(int) 2 * rdfentry_", num_events=100_000)
df_minor = make_df(b1_expr="(int) 2 * rdfentry_ + 1", num_events=1_000)

# Hyper-parameters shared by the data loader and the training loop.
batch_size = 256
num_epochs = 20
29
31
32
33# Function to train the model and print useful loss statistics
def train_model(model, optimizer, dataloader):
    """Train ``model`` for ``num_epochs`` epochs and print per-epoch statistics.

    The loader is split once into a training and a validation part
    (``test_size=0.2``). Every epoch performs one optimisation pass over the
    training batches followed by one gradient-free pass over the validation
    batches, and prints accuracy and mean loss for each pass.

    Args:
        model: Callable applied to a feature batch ``X`` to produce outputs.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called
            around each backward pass.
        dataloader: Object providing ``train_test_split(test_size=...)``; both
            returned splits must provide ``as_torch()`` yielding ``(X, y)``.

    Reads the module-level names ``num_epochs`` and ``loss_fn``; both must be
    defined before this function is called.
    """
    train, val = dataloader.train_test_split(test_size=0.2)
    for _ in range(num_epochs):
        train_correct = 0
        train_total = 0
        train_losses = []
        for X, y in train.as_torch():
            # Gradients accumulate across backward() calls, so clear them first.
            optimizer.zero_grad()
            outputs = model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            # Record the batch loss; the epoch summary averages this list.
            train_losses.append(loss.item())

            # Outputs above 0.5 count as class 1.
            preds = (outputs > 0.5).float()
            train_correct += (preds == y).sum().item()
            train_total += y.size(0)
        print(
            f"Training => Accuracy: {int(train_correct / train_total * 100000) / 100000}; Loss: {int(sum(train_losses) / len(train_losses) * 100000) / 100000}"
        )
        val_losses = []
        val_correct = 0
        val_total = 0
        for X, y in val.as_torch():
            # Validation only evaluates the model; no gradients are needed.
            with torch.no_grad():
                outputs = model(X)
                loss = loss_fn(outputs, y)
            val_losses.append(loss.item())

            preds = (outputs > 0.5).float()
            val_correct += (preds == y).sum().item()
            val_total += y.size(0)

        print(
            f"Validation => Accuracy: {int(val_correct / val_total * 100000) / 100000}; Loss: {int(sum(val_losses) / len(val_losses) * 100000) / 100000}\n"
        )
71
72
73# Oversampling strategy: more batches of the underrepresented class
74# Takes more time per epoch, but each epoch is more balanced & effective
76 [df_major, df_minor],
77 batch_size=batch_size,
78 target="b2",
79 set_seed=seed,
80 load_eager=True, # Must be enabled for resampling
81 sampling_type="oversampling", # Can also be "undersampling"
82 sampling_ratio=0.1, # ~10% of the data will be from the underrepresented class, instead of ~1%
83)
84
# Linear model with one input feature and one output, optimised with Adam.
oversampling_model = torch.nn.Linear(in_features=1, out_features=1)
oversampling_optimizer = torch.optim.Adam(params=oversampling_model.parameters())

print("Training with oversampling:")
train_model(
    model=oversampling_model,
    optimizer=oversampling_optimizer,
    dataloader=dl_oversampled,
)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...
static uint64_t sum(uint64_t i)
Definition Factory.cxx:2338