Dataset和DataLoader的书写

  • 时间:
  • 浏览:
  • 来源:互联网

文章目录

  • 书写Dataset和DataLoader
  • 一、导入必要的库
  • 二、继承Dataset自定义数据读写
  • 三、写DataLoader完成batch和shuffle等操作
  • 四、调用方法以及打印一些值
    • 1. 数据集的长度
    • 2. train_loader的长度
    • 3. crop的作用
  • 五、 参考


书写Dataset和DataLoader

说明:搭建网络需要数据的读入和预处理,model的构建,损失函数和优化策略的选择。该博客以PIL读取图像,构建最基本的Dataset和DataLoader。


一、导入必要的库

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as tr

import os 
from PIL import Image 

用PIL读取图像为例,还可以用matplotlib和cv2读取图像

二、继承Dataset自定义数据读写

class MyDataset(Dataset):
    def __init__(self,root_path):
        filenames = sorted(os.listdir(root_path))#filenames是列表 相当于存储的路径下的所有子文件名(查阅os.listdir)
        self.transforms = tr.Compose([
            tr.CenterCrop(256),#剪裁成等大的图像块 因为之后batch操作时必须保证每个batch中图像大小一致(尝试注释掉这行,看报错信息)
            tr.ToTensor(),#将数据类型转换为tensor 之后才能加载到gpu上训练
            tr.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))#这个是标准化处理,看具体数据情况
        ])
        
        self.dataset=[]
        for name in filenames:
            file=os.path.join(root_path,name)#就相当于根路径和文件名连在一起
            image=Image.open(file).convert('RGB')#用PIL读取并转换为RGB格式
            self.dataset.append(self.transforms(image))#调用之前的变换(私有属性)不再赘述
        
    def __len__(self):                     #这两个函数必须在Dataset里实现,一个返回数据集长度,一个通过索引来访问数据集中每个数据
        return len(self.dataset)
        
    def __getitem__(self,idx):
        return self.dataset[idx]

三、写DataLoader完成batch和shuffle等操作

root_path=r'/path/to/data'#换成你自己数据集存放位置(本博客用的数据集是Set14,由14张图片组成)

train_dataset=MyDataset(root_path)#实例化上述类
train_loader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True)#加载数据集,完成dataset里无法对数据进行的分批等操作

四、调用方法以及打印一些值

1. 数据集的长度

print(len(train_dataset))
# print(train_dataset.__len__())
# print(train_dataset.__getitem__(0))

结果 :14

如果我们注释掉第二部分中的__len__()这个方法后会报如下的错:

TypeError: object of type 'MyDataset' has no len()

2. train_loader的长度

print(len(train_loader))

结果: 4
那么为什么train_loader的长度是4? 我们将其中的元素大小打印出来:

for a in train_loader:
    print(a.size())

结果如下:

torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256])
torch.Size([2, 3, 256, 256])

我们看看train_loader里到底存了什么,是4 * 3 * 256 * 256大小的张量(高维矩阵)。首先我们知道数据集里有14张图片,我们在DataLoader里设置的batch_size大小为4,没错第一维就是这个批的大小,因为有14张图片,4个为一批,所以最后会有两个剩余,默认人情况下会把不足的看作一批,当然也可以用DataLoader的参数使得不足的忽略掉,自己去查;第二维存的是通道数,我们读入的图片是RGB格式所以是三通道的,灰度图是单通道的,最后两个是图像的H和W(即高和宽)——原图像H和W不一定一致哦,这是我们crop(剪裁)的结果,为了进行批处理。

3. crop的作用

如果我们将剪裁的那行代码注释掉,最后在访问train_loader中的元素时会报如下的错:

RuntimeError: stack expects each tensor to be equal size, but got [3, 512, 512] at entry 0 and [3, 361, 250] at entry 1

也就是说,在分批处理的时候,一定要保证每批中的数据大小一致。

五、 参考

详细内容可参考pytorch官方文档:
torch.utils.data

本文链接http://www.dzjqx.cn/news/show-617486.html