SuooL's Blog

Lying low in midsummer, storing up brilliance for spring

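Custom Datasets in PyTorch and the Kaggle 101 Digit Recognizer

Introduction

In the previous post I used a feed-forward neural network to implement simple handwritten digit recognition. That code cannot be carried over to Kaggle as-is, because Kaggle provides its data as CSV files, so a custom PyTorch dataset class is needed to complete this introductory problem.

The outline of this post is:

  1. Custom Dataset
  2. Building, saving, and loading the model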

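Custom Dataset

Data loading in PyTorch mainly involves three classes:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

Roughly speaking, each one wraps the previous one: a Dataset is wrapped in a DataLoader, and the DataLoader is wrapped in a DataLoaderIter.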
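
The nesting described above can be seen in a few lines. This is only a minimal sketch on random stand-in data: the tensors, sizes, and variable names are assumptions made for illustration, and the built-in TensorDataset takes the place of a custom class. The iterator is obtained through iter(), so the sketch does not depend on the name of the iterator class, which is private and has changed between PyTorch versions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 10 fake "images" of 784 pixels and 10 fake labels.
features = torch.randn(10, 784)
labels = torch.randint(0, 10, (10,))

# 1. The Dataset knows how to return one sample by index.
dataset = TensorDataset(features, labels)
one_x, one_y = dataset[0]

# 2. The DataLoader wraps the Dataset and decides how samples are batched.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# 3. iter() asks the DataLoader for its iterator object, which wraps the loader.
loader_iter = iter(loader)
batch_x, batch_y = next(loader_iter)

print(one_x.shape)    # shape of a single sample
print(batch_x.shape)  # shape of a batch of 4 samples
```

With 10 samples and a batch size of 4, len(loader) is 3, the last batch being smaller because drop_last defaults to False.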

torch.utils.data.Dataset

The source code of this class is as follows:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

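The docstring in the source above makes the purpose of the class clear: it is an abstract class, and a custom Dataset must inherit from it and implement the following two methods:

  1. __getitem__()
  2. __len__()

The skeleton of a custom class looks like this: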

import torch.utils.data


class CustomDataset(torch.utils.data.Dataset):  # must inherit from data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        # Do any other one-time setup here.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # Note that step 1 reads ONE sample, not the whole dataset.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
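
To make the skeleton concrete, here is a minimal sketch of a filled-in subclass. SquaresDataset is a hypothetical toy class invented for this illustration (it reads no files); it also shows the __add__ method from the Dataset source quoted earlier, which chains two datasets into a ConcatDataset.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # Return exactly one (input, target) pair per call.
        if not 0 <= index < self.size:
            raise IndexError(index)
        x = torch.tensor(float(index))
        return x, x * x

    def __len__(self):
        return self.size


small = SquaresDataset(3)
large = SquaresDataset(5)

# Dataset.__add__ returns a ConcatDataset that chains the two datasets.
combined = small + large

# Index 4 is past the 3 samples of `small`, so it maps to large[1].
x, y = combined[4]
print(len(combined), x.item(), y.item())  # 8 1.0 1.0
```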

For reference, here is the official MNIST dataset class:

class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        # PASS

    def __repr__(self):
        # PASS
        return fmt_str

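The first method, __getitem__(), is the most important one, because it defines how each sample is read; the image-reading example above shows how to implement it. It is worth mentioning that PyTorch also provides many common transforms in torchvision.transforms, which this post does not cover in detail. The ones I use most are Resize, RandomCrop, Normalize, and ToTensor. ToTensor is especially important: it converts a PIL image or a numpy image into a torch.Tensor. Its handling of numpy arrays seems more restrictive, though (it expects an H x W x C array and only rescales uint8 data to [0.0, 1.0]), so I recommend reading images with PIL in __getitem__() rather than with skimage.io.

The second method, __len__(), is simple: it returns the length of the whole dataset.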

Following this approach, my version 1.0 looks like this:

import numpy as np
import pandas as pd
import torch.utils.data
from PIL import Image


class DigitDataSet(torch.utils.data.Dataset):
    def __init__(self, data_path, train=True, transform=None, target_transform=None):
        self.transform = transform
        self.train = train
        self.target_transform = target_transform
        # Training CSV: a label column followed by 784 pixel columns.
        # Test CSV: the 784 pixel columns only.
        self.data = pd.read_csv(data_path, header=0, sep=',')

    def __getitem__(self, index):
        row = self.data.iloc[index].values
        if self.train:
            img = row[1:].reshape(28, 28).astype(np.uint8)
            target = int(row[0])
        else:
            img = row.reshape(28, 28).astype(np.uint8)
            target = -1  # the test set has no labels

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.data)

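This is clearly not an elegant implementation, and it has some problems. For example, it cannot split the data into a training set and a validation set; it can only read the training set or the test set. Once the model is trained, the result is left to luck, because there is no held-out data to check it against.

An improved version will come in a later post; this one is purely about getting the problem solved.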
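
One possible way around the missing validation split, offered as a sketch and not as the improved version promised above, is torch.utils.data.random_split. The example uses a small synthetic DataFrame as a stand-in for the training CSV (a label column followed by 784 pixel columns); the row count, the 80/20 ratio, and the batch size are arbitrary assumptions.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for the training CSV: 100 rows, label + 784 pixels.
rng = np.random.RandomState(0)
frame = pd.DataFrame(
    rng.randint(0, 256, size=(100, 784)),
    columns=["pixel%d" % i for i in range(784)],
)
frame.insert(0, "label", rng.randint(0, 10, size=100))

# Convert everything to tensors once instead of slicing the DataFrame per sample.
pixels = torch.tensor(frame.iloc[:, 1:].values, dtype=torch.float32) / 255.0
labels = torch.tensor(frame["label"].values, dtype=torch.long)
full_set = TensorDataset(pixels, labels)

# Hold out 20% of the rows for validation (the ratio is an arbitrary choice).
val_size = len(full_set) // 5
train_set, val_set = random_split(full_set, [len(full_set) - val_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

xb, yb = next(iter(train_loader))
print(len(train_set), len(val_set), xb.shape, yb.shape)
```

random_split works on any Dataset, so the same two lines could be applied to an instance of the DigitDataSet class above instead of the TensorDataset used here.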

torch.utils.data.DataLoader

The comments in the source of this class are very detailed, so I paste them here directly:

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # PASS
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

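Its __init__ method shows that the main parameters are:

  • dataset: the custom dataset defined above.
  • collate_fn: the function that merges a list of samples into a batch.
  • num_workers: a very simple way to load data in parallel. Set it to 1 or more and the data is prefetched by that many worker subprocesses.

This class is really just a frame around the DataLoaderIter discussed next. It does two things:

  1. It defines a set of member variables that are later handed to the DataLoaderIter.
  2. It has an __iter__() method that "packs" itself into a DataLoaderIter.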
def __iter__(self):
    # In the source quoted above the iterator class is named _DataLoaderIter.
    return _DataLoaderIter(self)
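
The collate_fn parameter is easiest to understand when the default cannot be used, for example when samples have different lengths. The following is a minimal sketch under that assumption: VariableLengthDataset and pad_collate are hypothetical names invented for this illustration, and padding with pad_sequence is just one possible way to merge such samples.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    """Hypothetical dataset whose samples are 1-D tensors of different lengths."""

    def __init__(self):
        self.samples = [torch.ones(n) for n in (2, 5, 3, 4)]

    def __getitem__(self, index):
        sequence = self.samples[index]
        return sequence, len(sequence)

    def __len__(self):
        return len(self.samples)


def pad_collate(batch):
    """Merge a list of (sequence, length) samples into one padded batch."""
    sequences, lengths = zip(*batch)
    # Zero-pad every sequence to the length of the longest one in the batch.
    padded = pad_sequence(list(sequences), batch_first=True)
    return padded, torch.tensor(lengths)


# num_workers=0 loads in the main process; a positive value would start
# that many worker subprocesses instead.
loader = DataLoader(VariableLengthDataset(), batch_size=2,
                    collate_fn=pad_collate, num_workers=0)

batches = list(loader)
for padded, lengths in batches:
    print(padded.shape, lengths.tolist())
```

The DataLoader calls __getitem__() once per index in the batch and hands the resulting list to collate_fn, so the function receives a plain Python list of whatever __getitem__() returns.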

torch.utils.data.dataloader.DataLoaderIter

As mentioned above, DataLoader is a frame around DataLoaderIter: it passes a set of parameters to the DataLoaderIter and packs itself into it.

Consider the following skeleton:

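from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    # define your own dataset here (__getitem__ and __len__)
    ...


dataset = CustomDataset()
dataloader = DataLoader(dataset, ...)  # remaining arguments elided

for data in dataloader:
    # training...
    ...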

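Three things happen in the for loop:

  • dataloader.__iter__() is called, which creates a DataLoaderIter.
  • DataLoaderIter.__next__() is called repeatedly to obtain batches. Each call invokes the dataset's __getitem__() method several times (in worker subprocesses if num_workers > 0) and then uses collate_fn to merge the samples into a batch. Shuffling and the sampler also play a part here, but I will not go into them.
  • Once all the data has been read, __next__() raises StopIteration and the for loop ends. That iterator is now exhausted; looping over the dataloader again creates a new one.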
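
The three steps above can be written out by hand. This is a minimal sketch on a stand-in TensorDataset of ten numbers (the data and batch size are assumptions for illustration); it also shows that the DataLoader itself remains usable after one pass, since every new loop asks it for a fresh iterator.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
dataloader = DataLoader(dataset, batch_size=4)

# What `for data in dataloader:` does behind the scenes.
loader_iter = iter(dataloader)          # calls dataloader.__iter__()
batch_sizes = []
while True:
    try:
        (data,) = next(loader_iter)     # calls loader_iter.__next__()
    except StopIteration:               # raised once every sample was served
        break
    batch_sizes.append(len(data))

print(batch_sizes)  # [4, 4, 2]

# The exhausted iterator is discarded; a second loop gets a fresh one.
second_pass = sum(1 for _ in dataloader)
print(second_pass)  # 3
```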

Building, saving, and loading the model

Building and saving the model works exactly as in the previous post, so I will not repeat it here.
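
For readers who have not seen the previous post, saving presumably follows the usual state_dict pattern, which is what the load_state_dict call in the prediction script expects. The sketch below rests on that assumption: it rebuilds a network with the same layout as the NeuralNet class in the prediction script, saves its parameters to a temporary file (the path is chosen only for this illustration), and loads them into a second instance.

```python
import os
import tempfile

import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    """Fully connected network with one hidden layer."""

    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = NeuralNet(784, 512, 10)

# Save only the learned parameters (the state_dict), not the whole object.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
torch.save(model.state_dict(), checkpoint_path)

# To load, rebuild the same architecture and copy the parameters into it.
restored = NeuralNet(784, 512, 10)
restored.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
restored.eval()

# Every parameter tensor of the restored model matches the original.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```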

The code for loading the model and making predictions is as follows:

#!/usr/bin/python
# coding: utf-8

'''
Created on 2018-10-29
Update on 2018-10-29
Author: SuooL
Github: https://github.com/SuooL
'''

import csv
import os

# Multi-core CPU settings. They are set before numpy and torch are imported,
# because the threading libraries typically read them when first loaded.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Run on the CPU
device = torch.device('cpu')

# Hyperparameters (must match the saved model)
input_size = 784
hidden_size = 512
num_classes = 10


# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


model = NeuralNet(input_size, hidden_size, num_classes).to(device)
model.load_state_dict(torch.load('model.ckpt', map_location=device))
model.eval()

# Predict on the test set
test_data = pd.read_csv('test.csv', header=0, sep=',')

# ToTensor maps pixel values from [0, 255] to [0.0, 1.0]
transform1 = transforms.Compose([
    transforms.ToTensor(),
])

# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    with open('submission.csv', 'w', newline='') as csv_file:
        writer = csv.writer(csv_file, dialect='excel')
        writer.writerow(["ImageId", "Label"])
        # The original loop hard-coded 28000 rows; len() reads it from the file.
        for index in range(len(test_data)):
            img = Image.fromarray(
                test_data.iloc[index].values.reshape(28, 28).astype(np.uint8))
            img_tensor = transform1(img)
            images = img_tensor.reshape(-1, 28 * 28).to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            print("picture %d, predicted number is %s" % (index + 1, predicted[0].item()))
            writer.writerow([index + 1, predicted[0].item()])
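
The script above predicts one row at a time. A batched variant is sketched below as one possible alternative, not as the code used for the actual submission. It runs on a synthetic stand-in for test.csv and an untrained stand-in network with the same layer sizes, so the predicted labels are meaningless; the row count and the batch size of 16 are arbitrary assumptions. In practice the NeuralNet class and the loaded checkpoint from the script above would take the place of the stand-in model.

```python
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# Synthetic stand-in for test.csv: 50 rows of 784 pixel values, no label column.
rng = np.random.RandomState(0)
test_data = pd.DataFrame(rng.randint(0, 256, size=(50, 784)))

# Untrained stand-in with the same layer sizes as the network above.
model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))
model.eval()

# Same scaling that ToTensor applies to uint8 input: [0, 255] -> [0.0, 1.0].
images = torch.tensor(test_data.values, dtype=torch.float32) / 255.0

predictions = []
with torch.no_grad():
    for batch in images.split(16):            # 16 rows per forward pass
        outputs = model(batch)
        predictions.append(outputs.argmax(dim=1))
predictions = torch.cat(predictions)

# Kaggle's ImageId column starts at 1.
submission = pd.DataFrame({
    "ImageId": np.arange(1, len(predictions) + 1),
    "Label": predictions.numpy(),
})
print(submission.head())
# submission.to_csv("submission.csv", index=False) would write the file.
```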

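Summary

This post was mostly about getting familiar with custom datasets in PyTorch; I did not pay much attention to the quality of the code. It counts as a first attempt at Kaggle's introductory problem.

I will update the code later to make the implementation more elegant.

The next post will cover how to use some basic, classic algorithms for NLP.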
