Refactor loading of data to use a data loader
I would suggest wrapping the data loading in a PyTorch `DataLoader` and just using a sampler for shuffling and batching.
One needs to:
- Implement a dataset that takes care of the loading and tokenization and returns instances. I think a map-style dataset that returns a dict/object (like `SentenceEvidence`) would work; see the first sketch after the usage example below.
- Implement a `collate_fn` that takes a list of these objects and returns a batch; see the second sketch below.
- Then one can just use:
```python
from torch.utils.data import DataLoader

def collate_and_pad_batch(...):
    ...

dataset = EraserDataset(..., split='train')
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
                    collate_fn=collate_and_pad_batch)
for batch in loader:
    # do training
    ...
```
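For illustration, a minimal sketch of such a map-style dataset. The constructor arguments, the `load_split` helper, and the field names (`sentence`, `label`) are hypothetical placeholders, not this repo's actual API; only the `__len__`/`__getitem__` contract is PyTorch's.

```python
import torch
from torch.utils.data import Dataset

class EraserDataset(Dataset):
    """Map-style dataset: load one split eagerly, tokenize per item."""

    def __init__(self, data_dir, tokenizer, split="train"):
        # `load_split` is a hypothetical helper that reads the raw
        # annotations for one split into a list of dicts
        self.instances = load_split(data_dir, split)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        # tokenizing here (rather than in __init__) lets
        # num_workers > 0 parallelize it across worker processes
        instance = self.instances[idx]
        token_ids = self.tokenizer(instance["sentence"])
        return {
            "token_ids": torch.tensor(token_ids, dtype=torch.long),
            "label": torch.tensor(instance["label"], dtype=torch.long),
        }
```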
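And a matching sketch of the collate function, assuming the dict layout above; `pad_sequence` pads the variable-length token tensors to the longest sequence in the batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_and_pad_batch(batch):
    """Stack a list of instance dicts into one padded batch dict."""
    token_ids = [item["token_ids"] for item in batch]
    # keep the true lengths so the model can mask out the padding
    lengths = torch.tensor([len(t) for t in token_ids], dtype=torch.long)
    padded = pad_sequence(token_ids, batch_first=True, padding_value=0)
    labels = torch.stack([item["label"] for item in batch])
    return {"token_ids": padded, "lengths": lengths, "labels": labels}
```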
For more information see the `torch.utils.data` documentation.