前面的文章已经将GAN的数据处理,GAN的网络搭建以及损失函数的设计都完成了,就差训练和出结果这一步了,本文将对GAN模型进行训练。

前面已经将繁琐的部分解决了,接下来就差训练了。训练GAN模型其实很简单,就是开启tensorflow数据流让网络参数随着训练步数更新优化即可。我们先 贴上代码再分析。

def train(self):

    # initialize all variables初始化各个变量
    tf.global_variables_initializer().run()

    # graph inputs for visualize training results
    #创造噪声z,GAN中应用的为均值分布,创造(64,62)大小的-1到1之间的
    self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

    # saver to save model 将训练好的模型参数保存起来
    self.saver = tf.train.Saver()

    # summary writer 将训练记录在log下
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name,
                                        self.sess.graph)

    # restore check-point if it exits
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        start_epoch = int(checkpoint_counter / self.num_batches)
        start_batch_id = checkpoint_counter - start_epoch * self.num_batches
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    # loop for epoch
    start_time = time.time()
    for epoch in range(start_epoch, self.epoch):

        # get batch data
        # 由于batchsize为64,遍历70000张图片需要1093次
        for idx in range(start_batch_id, self.num_batches):
            #提取处理好的固定位置图片,data_X的按批次处理后的图片位置,一个批次64张图片
            batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]) \
                                        .astype(np.float32)
            batch_images = self.data_X[idx * self.batch_size:(idx + 1) \
                                        * self.batch_size]

            # update D network sess.run喂入数据优化更新D网络,并在tensorboard中更新
            _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum,
                                     self.d_loss],feed_dict={
                                     self.inputs: batch_images, self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # update G network sess.run喂入数据优化更新G网络,并在tensorboard中更新
            _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum,
                                        self.g_loss],feed_dict={self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # display training status
            counter += 1
            #训练一个batchsize打印一下loss,一个epoch打印1093次我认为没这个必要,
            #50次batchsize后打印一下
            if np.mod(counter, 50) == 0:
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss= %.8f,
                        g_loss= %.8f" % (epoch, idx, self.num_batches,
                        time.time() - start_time,d_loss, g_loss))

            # save training results for every 300 steps 训练300步保存一张图片
            if np.mod(counter, 300) == 0:
                #生成一张该阶段下的由生成器生成的“假图片”
                samples = self.sess.run(self.fake_images,
                                        feed_dict={self.z: self.sample_z})
                #此处计算生成图片的小框图片的排布,本处为8×8排布
                tot_num_samples = min(self.sample_num, self.batch_size)
                manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                save_images(samples[:manifold_h * manifold_w, :, :, :],
                            [manifold_h, manifold_w],'./'
                            + check_folder(self.result_dir
                            + '/' + self.model_dir) + '/' + self.model_name
                            +'_train_{:02d}_{:04d}.png'.format(epoch, idx))

        # After an epoch, start_batch_id is set to zero 经过一个epoch后start_batch_id置为0
        # non-zero value is only for the first epoch after loading pre-trained model
        start_batch_id = 0

        # save model
        self.save(self.checkpoint_dir, counter)

        # show temporal results 经过一个epoch后输出一张全类别的图片,用于比较epoch后的差别
        self.visualize_results(epoch)

    # save model for final step 当epoch全部训练完后保存checkpoint
    self.save(self.checkpoint_dir, counter)

上述代码描述了:

  • 首先初始化变量
  • 创造噪声z用于输入G中产生fake图片
  • 加载checkpoint用于检查训练网络
  • 开始循环训练
  • 按批次加载数据和标签
  • 更新D网络,更新G网络
  • 可视化结果处理
  • 保存网络参数

具体的代码我就不一行行的去解释了,其中涉及到一些地址保存和tensorboard的内容我就不详细展开了,详细的就看看我的github的完整代码吧。至于可 视化这一部分我可能会出一篇文章单独说说。至此GAN网络的搭建已经结束了,至于实验结果可以直接到我的github上看结果就好了,我就不把训练结果贴 上来了,大家有什么问题可以一起讨论,前提是我有时间,哈哈!

我的GANs的完整代码:

tensorflow-GANs

pytorch-GANs

大家感觉可以就在github项目上点一下关注,哈哈

谢谢观看,希望对您有所帮助,欢迎指正错误,欢迎一起讨论!!!