Load datasets like a pro!

Datasets

Dataset and DataLoader are two of the most commonly used classes in PyTorch. They provide a clean and efficient way to load data lazily, using generators under the hood. This means you don't have to load your entire dataset into memory: samples are loaded only when they are needed during training.

You have probably already encountered Datasets and DataLoaders when you first learned PyTorch, as most tutorials use datasets that ship with the library. At first glance, these classes can be confusing. In real scenarios, datasets do not always come organized as nicely as we would like: some have a simple structure (for example, a single CSV file), while others are spread across directories and sub-directories. In this article I will demonstrate how to read a dataset with an unusual directory structure using the Dataset class. In future articles, I will cover DataLoaders as well.

Kaggle dataset

We can take a look at this dataset from Kaggle. So, here's the lowdown: this dataset is like a treasure chest with two cool folders – 'pos' and 'neg.' 'Pos' is where the good vibes hang out, filled with text files bursting with positive movie reviews. On the flip side, 'neg' is the spot for a bit of drama and criticism, hosting reviews with a less-than-stellar outlook.
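Concretely, the directory structure looks roughly like this (the individual file names here are illustrative, not the actual Kaggle file names):

```
txt_sentoken/
├── neg/
│   ├── review_001.txt
│   └── ...
└── pos/
    ├── review_001.txt
    └── ...
```

Each .txt file contains a single movie review, and the folder it lives in tells us its sentiment label.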

Picture this: you've got text files telling tales of movie experiences – from the super-happy "I love this movie" vibes to the not-so-impressed "Why did I waste my time?" moments. It's like peeking into a box of moviegoer emotions.

Now, we're not just here for the popcorn. Let's get down to business and build a PyTorch dataset to make sense of all these reviews. With our 'MovieReviewDataset' class, we're turning these text files into a playground for training models that can tell if a review is all sunshine and rainbows or more of a stormy weather situation.

Building our Dataset class

Importing necessary libraries

import os
from torch.utils.data import Dataset

Defining the class

We create a MovieReviewDataset class that inherits from PyTorch's Dataset base class. All the methods that follow live inside this class.

class MovieReviewDataset(Dataset):

Initializing the class

The constructor stores the root directory and the optional transform, then builds the lists of file paths and labels:

def __init__(self, root_dir, transform=None):
    self.root_dir = root_dir
    self.transform = transform
    self.file_paths, self.labels = self._load_data()

Defining the len method

def __len__(self):
    return len(self.file_paths)

The __len__ method returns the total number of samples in the dataset.

Defining the getitem method

def __getitem__(self, idx):
    file_path = self.file_paths[idx]
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    label = self.labels[idx]

    if self.transform:
        text = self.transform(text)

    return {'text': text, 'label': label}

The __getitem__ method loads and returns a single sample from the dataset at the given index (idx). It reads the text from the file at the stored file path, retrieves the corresponding label, and applies the optional transform. I will not cover transforms in this article, so you do not need to worry about them; for now, assume the transform does nothing.
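As a quick illustration of what a transform could look like, here is a hypothetical one that lowercases the review text and trims surrounding whitespace (the name and behavior are my own example, not part of the dataset):

```python
def simple_transform(text):
    # Hypothetical transform: lowercase the review and trim whitespace.
    return text.lower().strip()

# It would be passed to the dataset like this:
# dataset = MovieReviewDataset(root_directory, transform=simple_transform)
```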

Defining the load_data method

def _load_data(self):
    file_paths = []
    labels = []
    for label, sentiment in enumerate(['neg', 'pos']):
        folder_path = os.path.join(self.root_dir, sentiment)
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            file_paths.append(file_path)
            labels.append(label)

    return file_paths, labels

The _load_data method populates the lists of file paths and labels by iterating through the 'neg' and 'pos' folders. Unlike __len__ and __getitem__, this helper is not required by the Dataset API; it is just a convenient place to do the setup work.
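Note how enumerate assigns the labels: because 'neg' comes first in the list, negative reviews get label 0 and positive reviews get label 1.

```python
# enumerate pairs each sentiment folder with an integer label
for label, sentiment in enumerate(['neg', 'pos']):
    print(label, sentiment)
# 0 neg
# 1 pos
```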

Putting everything in one place

import os
from torch.utils.data import Dataset

class MovieReviewDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths, self.labels = self._load_data()

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()

        label = self.labels[idx]

        if self.transform:
            text = self.transform(text)

        return {'text': text, 'label': label}

    def _load_data(self):
        file_paths = []
        labels = []
        for label, sentiment in enumerate(['neg', 'pos']):
            folder_path = os.path.join(self.root_dir, sentiment)
            for filename in os.listdir(folder_path):
                file_path = os.path.join(folder_path, filename)
                file_paths.append(file_path)
                labels.append(label)

        return file_paths, labels

We can now use this MovieReviewDataset class to load data.

root_directory = '/kaggle/input/movie-review-dataset/txt_sentoken/'
dataset = MovieReviewDataset(root_directory)
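If you want to try the class without downloading the Kaggle files, here is a self-contained sketch that builds a tiny mock 'neg'/'pos' tree in a temporary directory (the review texts are made up) and reads it back. The MovieReviewDataset class is the same one defined above, repeated so the snippet runs on its own:

```python
import os
import tempfile
from torch.utils.data import Dataset

class MovieReviewDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths, self.labels = self._load_data()

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        with open(self.file_paths[idx], 'r', encoding='utf-8') as file:
            text = file.read()
        label = self.labels[idx]
        if self.transform:
            text = self.transform(text)
        return {'text': text, 'label': label}

    def _load_data(self):
        file_paths, labels = [], []
        for label, sentiment in enumerate(['neg', 'pos']):
            folder_path = os.path.join(self.root_dir, sentiment)
            for filename in os.listdir(folder_path):
                file_paths.append(os.path.join(folder_path, filename))
                labels.append(label)
        return file_paths, labels

# Build a tiny mock dataset: one made-up review per sentiment folder.
root = tempfile.mkdtemp()
for sentiment, review in [('neg', 'Why did I waste my time?'),
                          ('pos', 'I love this movie!')]:
    os.makedirs(os.path.join(root, sentiment))
    with open(os.path.join(root, sentiment, 'review_0.txt'), 'w',
              encoding='utf-8') as f:
        f.write(review)

dataset = MovieReviewDataset(root)
print(len(dataset))          # 2
print(dataset[0]['label'])   # 0 (the 'neg' review)
```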