class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    # ``__len__`` is deliberately not stubbed here: a ``__len__`` that raises
    # would also make truth-testing an instance (``if dataset:``) raise.
class CustomDataset(torch.utils.data.Dataset):  # must inherit from data.Dataset
    """Skeleton for a user-defined dataset.

    Fill in the three methods below: ``__init__`` for one-time setup,
    ``__getitem__`` to produce a single sample, and ``__len__`` to report the
    dataset size. As written, the dataset is empty (``len`` is 0) and
    ``__getitem__`` returns ``None``.
    """

    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        # Do any one-time initialization work here.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # Note on step 1: "read one data" means exactly ONE sample, not a batch.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
class MNIST(torch.utils.data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: if either processed file is missing below ``root``.
    """

    # Location of the prepared data below ``root``, as named in the docstring.
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Populate the attributes that __getitem__ / __len__ read.
        # NOTE(security): torch.load unpickles the file, so ``root`` must only
        # point at trusted data.
        # NOTE(review): assumes each processed file stores a ``(data, labels)``
        # pair, as torchvision's processed MNIST files do -- confirm.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                self._processed_path(self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                self._processed_path(self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the active (train or test) split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _processed_path(self, filename):
        """Return the path of ``filename`` inside the processed-data folder."""
        return os.path.join(self.root, self.processed_folder, filename)

    def _check_exists(self):
        """Return True only if both the training and the test file exist."""
        return all(
            os.path.exists(self._processed_path(name))
            for name in (self.training_file, self.test_file)
        )

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        # Placeholder: the download logic is not implemented here.
def __getitem__(self, index):
    """Return the ``(image, target)`` pair for row ``index`` of ``self.data``.

    NOTE(review): the enclosing class header is not part of this excerpt;
    ``self.data`` is presumably a pandas DataFrame (it is indexed via ``.iloc``).

    Args:
        index (int): Row position; negative values count from the end.

    Returns:
        tuple: ``(image, target)``. In training mode ``target`` is the value
        in the first column of the row; otherwise it is the placeholder ``-1``.

    Raises:
        IndexError: if ``index`` is outside the table.
    """
    # A one-element list indexer keeps the 2-D (1, n) shape that the slice
    # ``index:index + 1`` produced, but raises IndexError for an out-of-range
    # index instead of returning an empty frame (which then failed in reshape
    # with a ValueError), and it also resolves negative indices correctly.
    row = self.data.iloc[[index]]

    if self.train:
        # Column 0 is the label; the remaining columns are the pixel values.
        img = row.iloc[:, 1:].values.reshape(28, 28).astype(np.uint8)
        target = row.iloc[:, :1].values[0].tolist()[0]
    else:
        # No label column in this mode: every column is a pixel value.
        img = row.values.reshape(28, 28).astype(np.uint8)
        target = -1

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    # Class-level default so that ``__setattr__`` can read the flag on an
    # instance that has not set it yet (name-mangled to
    # ``_DataLoader__initialized``).
    __initialized = False

    def __setattr__(self, attr, val):
        """Store ``val`` under ``attr``, guarding the batching configuration.

        Raises:
            ValueError: if the loader is flagged as initialized and ``attr``
                is one of ``batch_size``, ``sampler`` or ``drop_last``.
        """
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        # Defining __setattr__ replaces the default behaviour, so the actual
        # assignment has to be delegated explicitly.
        super(DataLoader, self).__setattr__(attr, val)
''' Created on 2018-10-29 Update on 2018-10-29 Author: SuooL Github: https://github.com/SuooL '''
import csv
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
from PIL import Image
# Multi-core CPU settings: set the OpenMP / MKL thread-count variables to 8.
# NOTE(review): native libraries typically read these variables when they
# initialise, so assigning them after numpy/torch are imported may have no
# effect -- confirm the intended placement.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"