Datasets
PyTorch's Dataset and DataLoader are two commonly used classes for handling data. They provide a clean and efficient way to load data lazily, which means you don't have to load your entire dataset into memory. Using these classes, samples are read only when they are needed during training.
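To make the lazy-loading idea concrete, here is a minimal toy sketch. The LazyNumbers class is made up purely for illustration and is not part of PyTorch; the point is that __getitem__ is only called when an item is actually requested.

from torch.utils.data import Dataset, DataLoader

class LazyNumbers(Dataset):
    """A toy dataset: nothing is materialized until an index is requested."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        # In a real dataset this is where you would read a file from disk
        return idx * 2

loader = DataLoader(LazyNumbers(), batch_size=4)
for batch in loader:
    print(batch)  # batches of 4 items, produced on demand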
You have probably already encountered Datasets and DataLoaders when you first learned PyTorch, as most examples use datasets that ship with the library. At first glance, this data structure can be confusing. In real scenarios, datasets do not always come organized as nicely as we would like. Some datasets have a simple structure (for example, a single CSV file), while others are spread across directories and sub-directories. In this article I will demonstrate how to read a dataset with an unusual directory structure using the Dataset class. In future articles, I will cover DataLoaders as well.
Kaggle dataset
We can take a look at this dataset from Kaggle. So, here's the lowdown: this dataset is like a treasure chest with two cool folders – 'pos' and 'neg.' 'Pos' is where the good vibes hang out, filled with text files bursting with positive movie reviews. On the flip side, 'neg' is the spot for a bit of drama and criticism, hosting reviews with a less-than-stellar outlook.
Picture this: you've got text files telling tales of movie experiences – from the super-happy "I love this movie" vibes to the not-so-impressed "Why did I waste my time?" moments. It's like peeking into a box of moviegoer emotions.
Now, we're not just here for the popcorn. Let's get down to business and build a PyTorch dataset to make sense of all these reviews. With our 'MovieReviewDataset' class, we're turning these text files into a playground for training models that can tell if a review is all sunshine and rainbows or more of a stormy weather situation.
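Concretely, the directory layout we are assuming looks roughly like this (the individual file names below are placeholders; the actual files in the Kaggle dataset have their own names):

txt_sentoken/
├── neg/
│   ├── negative_review_001.txt
│   ├── negative_review_002.txt
│   └── ...
└── pos/
    ├── positive_review_001.txt
    ├── positive_review_002.txt
    └── ...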
Building our Dataset class
Importing necessary libraries
import os
from torch.utils.data import Dataset
Defining the class
class MovieReviewDataset(Dataset):
    def __init__(self, root_dir, transform=None):
Initializing the class
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths, self.labels = self._load_data()
Defining the len method
    def __len__(self):
        return len(self.file_paths)
The __len__ method returns the total number of samples in the dataset.
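With this in place, Python's built-in len() works on our dataset. For example, assuming the Kaggle directory used at the end of this article:

dataset = MovieReviewDataset('/kaggle/input/movie-review-dataset/txt_sentoken/')
print(len(dataset))  # total number of text files found under 'neg' and 'pos'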
Defining the getitem method
    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        label = self.labels[idx]
        if self.transform:
            text = self.transform(text)
        return {'text': text, 'label': label}
The __getitem__ method is responsible for loading and returning a sample from the dataset at the given index (idx). It reads the text from the file at the stored file path, retrieves the corresponding label, and applies the optional transformation. I will not cover transforms in this article, so you do not need to worry about that argument; assume it does nothing.
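For example, once the full class below is assembled, a sample can be fetched directly by index. The lowercase_text function here is just a hypothetical transform, included only to show where the transform argument plugs in:

def lowercase_text(text):
    # hypothetical transform: normalize the review text to lowercase
    return text.lower()

dataset = MovieReviewDataset('/kaggle/input/movie-review-dataset/txt_sentoken/',
                             transform=lowercase_text)
sample = dataset[0]
print(sample['label'])       # 0, since the 'neg' reviews are listed first
print(sample['text'][:100])  # first 100 characters of the (lowercased) review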
Defining the load_data method
    def _load_data(self):
        file_paths = []
        labels = []
        for label, sentiment in enumerate(['neg', 'pos']):
            folder_path = os.path.join(self.root_dir, sentiment)
            for filename in os.listdir(folder_path):
                file_path = os.path.join(folder_path, filename)
                file_paths.append(file_path)
                labels.append(label)
        return file_paths, labels
The _load_data method populates the lists of file paths and labels by iterating through the 'neg' and 'pos' folders. Unlike the __len__ and __getitem__ methods, this helper is not required by the Dataset interface; it is simply a convenient place to collect the file paths and labels.
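The enumerate call is what assigns the numeric labels: 'neg' maps to 0 and 'pos' maps to 1. You can see the mapping in isolation like this:

for label, sentiment in enumerate(['neg', 'pos']):
    print(label, sentiment)
# 0 neg
# 1 pos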
Putting everything in one place
import os
from torch.utils.data import Dataset


class MovieReviewDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths, self.labels = self._load_data()

    def __len__(self):
        # Total number of review files discovered in 'neg' and 'pos'
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Read the review text lazily, only when this index is requested
        file_path = self.file_paths[idx]
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        label = self.labels[idx]
        if self.transform:
            text = self.transform(text)
        return {'text': text, 'label': label}

    def _load_data(self):
        # Walk the 'neg' (label 0) and 'pos' (label 1) folders and record
        # every file path together with its label
        file_paths = []
        labels = []
        for label, sentiment in enumerate(['neg', 'pos']):
            folder_path = os.path.join(self.root_dir, sentiment)
            for filename in os.listdir(folder_path):
                file_path = os.path.join(folder_path, filename)
                file_paths.append(file_path)
                labels.append(label)
        return file_paths, labels
We can now use this MovieReviewDataset class to load data.
root_directory = '/kaggle/input/movie-review-dataset/txt_sentoken/'
dataset = MovieReviewDataset(root_directory)
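As a quick sanity check (the exact numbers depend on the files on disk), you can count how many reviews were picked up per class and peek at one sample:

from collections import Counter

print(Counter(dataset.labels))  # counts per label: 0 = 'neg', 1 = 'pos'
sample = dataset[5]             # any valid index works
print(sample['label'], sample['text'][:80])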