cifar10数据集是比较难训练和分类识别的,对于GAN模型生成cifar10数据对其特性也是一个不小的挑战。如果GAN模型可以很好的生成 cifar10数据图片的话,这个GAN的生成能力已经是很不错的了。今天我们说说如何把下载下来的cifar10数据集输出到GAN模型中用作训练。

数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000 张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。 注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

这10类都是各自独立的,不会出现重叠。下载地址在这里

拿到这样的数据集后,我们要如何把数据和标签整理好喂入GAN模型或者识别测试的程序中呢?这时候就需要我们对数据集做一定的处理了。 我们看看程序代码:

def load_cifar10(dataset_name):
    data_dir = os.path.join("./data", dataset_name)

    X_train = []
    Y_train = []

    dirname = './data/cifar/cifar-10-batches-py'

    for i in range(1, 6):
        fpath = os.path.join(dirname, 'data_batch_' + str(i))
        data, labels = load_batch(fpath)
        if i == 1:
            X_train = data
            Y_train = labels
        else:
            X_train = np.concatenate([X_train, data], axis=0)
            Y_train = np.concatenate([Y_train, labels], axis=0)

    fpath = os.path.join(dirname, 'test_batch')
    X_test, Y_test = load_batch(fpath)

    X_train = np.dstack((X_train[:, :1024], X_train[:, 1024:2048],
                         X_train[:, 2048:])) / 255.
    X_train = np.reshape(X_train, [-1, 32, 32, 3])
    X_test = np.dstack((X_test[:, :1024], X_test[:, 1024:2048],
                        X_test[:, 2048:])) / 255.
    X_test = np.reshape(X_test, [-1, 32, 32, 3])

    X = np.concatenate((X_train, X_test), axis=0)

    # if one_hot:
    #    Y_train = to_categorical(Y_train, 10)
    #    Y_test = to_categorical(Y_test, 10)

    Y_test = np.asarray(Y_test)
    y = np.concatenate((Y_train, Y_test), axis=0).astype(np.int)

    seed = 547
    np.random.seed(seed) #确保每次生成的随机数相同
    np.random.shuffle(X) #将mnist数据集中数据的位置打乱
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float)
    #创建了(70000,10)的标签记录,并且根据cifar10已有标签记录相应的10维数组
    for i, label in enumerate(y):
        y_vec[i, y[i]] = 1.0

    #返回归一化的数据和标签数组
    return X, y_vec

程序的大体上是从下载好的cifar10数据集将6个10000张图片先进行合并,将训练集和测试集分开。对于GAN而言可以将训练集和测试集合并。 然后再将数据集整理,将cifar10图片转换为32x32的彩色3通道图片,将图片放入X中保存。对于标签我们将10类数据转换成one_hot形式, 最后存放在y_vec中,我们在实现过程中打乱了数据集的数据位置,这样可以保证训练更加的随机。我们返回两个变量,一个是数据图片X 和对应的one_hot形式的标签y_vec。

在实际的训练过程中,这个形式的数据可以像我们之前讲解的训练mnist数据集 的代码使用,但是网络结构要进行一定的变换。具体的代码由于出文章的原因我暂时就不公布了,等我忙清了,我再分享到github上。

下面是我们训练GAN得到的实验效果图:

cifar10的生成效果还是一般般的,原因也是在于cifar10难于训练的,期待更好的模型完成图像生成。

可选择性参考label的处理

谢谢观看,希望对您有所帮助,欢迎指正错误,欢迎一起讨论!!!