Deep Reinforcement Learning for Snake

By 沥川

Deep Reinforcement Learning (Deep Q Network)

Snake Game

Demo: 11,000,000 frames (56 h of training)

The problem reinforcement learning tries to solve

  • Without a human teacher, how can an AI learn to play a game purely through its own actions and the feedback it receives?

The basic idea of reinforcement learning
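
An agent interacts with an environment step by step: it observes a state, chooses an action, and receives a reward together with the next state; over many episodes it learns a policy that maximizes the reward it collects. A minimal sketch of this interaction loop (illustrative pseudocode only; env and agent are placeholder names, not objects defined in this post):

# Illustrative agent-environment loop; env and agent are placeholders.
state = env.reset()
while True:
    action = agent.choose(state)                   # pick an action in the current state
    next_state, reward, done = env.step(action)    # the environment gives feedback
    agent.learn(state, action, reward, next_state, done)
    state = next_state
    if done:
        break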

Two central problems in reinforcement learning

  • credit assignment problem
  • explore-exploit dilemma

How the two problems are solved

  • Discounted Future Reward
  • Q-learning

  • ε-greedy exploration
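
A small tabular sketch of the three ideas above (illustrative only; the rest of the post replaces the Q-table with a convolutional network): the discounted future reward is R_t = r_t + γ·r_{t+1} + γ²·r_{t+2} + …, the Q-learning update is Q(s,a) ← Q(s,a) + α·(r + γ·max_a' Q(s',a') − Q(s,a)), and ε-greedy exploration picks a random action with probability ε.

# Minimal tabular sketch of the three ideas above (illustrative only).
import random

def discounted_return(rewards, gamma=0.99):
    # R_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...
    R = 0.0
    for r in reversed(rewards):
        R = r + gamma * R
    return R

def q_learning_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    # Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
    target = r + gamma * max(Q[s_next])
    Q[s][a] += alpha * (target - Q[s][a])

def epsilon_greedy(Q, s, epsilon):
    # explore with probability epsilon, otherwise exploit the best known action
    if random.random() < epsilon:
        return random.randrange(len(Q[s]))
    return max(range(len(Q[s])), key=lambda a: Q[s][a])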

The Deep Q Network algorithm
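
DQN replaces the Q-table with a convolutional network that maps a stack of game frames to one Q-value per action. Training interleaves ε-greedy play with learning from a replay memory: every transition (state, action, reward, next state, done) is stored in a deque, random minibatches are sampled from it, and the network is regressed toward the target r when the episode ended, or r + γ·max_a' Q(s',a') otherwise. This is exactly the structure of the getAction(), train() and run() methods shown below.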

CNN architecture
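
The network built in initConvNet() below takes an 84×84×4 input (the last four preprocessed frames), applies three convolutional layers (8×8 stride 4 followed by a 2×2 max-pool, 4×4 stride 2, 3×3 stride 1), and ends with a 512-unit fully connected layer and a linear output layer producing one Q-value for each of the 4 actions.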

Part 1: Game design

Reward: penalties

        # check if the worm has hit itself or the edge
        reward = -0.1
        done = False
        if self.wormCoords[HEAD]['x'] == -1 or self.wormCoords[HEAD]['x'] == CELLWIDTH or self.wormCoords[HEAD]['y'] == -1 or self.wormCoords[HEAD]['y'] == CELLHEIGHT:
            done = True
            #self.__init__() # game over
            reward = -1
            image_data = pygame.surfarray.array3d(pygame.display.get_surface())
            return image_data, reward, done
        for wormBody in self.wormCoords[1:]:
            if wormBody['x'] == self.wormCoords[HEAD]['x'] and wormBody['y'] == self.wormCoords[HEAD]['y']:
                done = True
                #self.__init__() # game over
                reward = -1
                image_data = pygame.surfarray.array3d(pygame.display.get_surface())
                return image_data, reward, done

Reward: bonuses

        # check if worm has eaten an apple
        if self.wormCoords[HEAD]['x'] == self.apple['x'] and self.wormCoords[HEAD]['y'] == self.apple['y']:
            # don't remove worm's tail segment
            self.apple = self.getRandomLocation(self.wormCoords) # set a new apple somewhere
            reward = 2
            self.totalscore = self.totalscore + 1
        else:
            del self.wormCoords[-1] # remove worm's tail segment
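
In summary, the reward signal is −0.1 for an ordinary step (to discourage aimless wandering), −1 when the snake hits the wall or its own body (the episode ends), and +2 when it eats an apple.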

Part 2: Building the model

Initializing the global variables

self.epoch = 0         # step counter
self.episode = 0       # episode counter
self.observe = 10000   # observation period (steps before training begins)
self.discft = 0.99     # discount factor
self.epsilon = 1       # initial epsilon
self.finep = 0.001     # final epsilon
self.REPLAYMEM = 50000 # replay memory size
self.batchsize = 32    # minibatch size
self.actions = 4       # number of actions
self.repmem = deque()  # double-ended queue holding the replay memory
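
repmem stores each transition as a 5-tuple (state, action, reward, next state, done), which is the layout train() below relies on when it unpacks data[0] through data[4]; the memory is meant to be capped at REPLAYMEM entries, with the oldest transitions evicted as new ones arrive.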

Initializing the CNN parameters (three convolutional layers)

# Init weight and bias
        self.w1 = tf.Variable(tf.truncated_normal([8, 8, 4, 32], stddev = 0.01))
        self.b1 = tf.Variable(tf.constant(0.01, shape = [32]))

        self.w2 = tf.Variable(tf.truncated_normal([4, 4, 32, 64], stddev=0.01))
        self.b2 = tf.Variable(tf.constant(0.01, shape = [64]))

        self.w3 = tf.Variable(tf.truncated_normal([3, 3, 64, 64], stddev = 0.01))
        self.b3 = tf.Variable(tf.constant(0.01, shape = [64]))

        self.wfc = tf.Variable(tf.truncated_normal([2304, 512], stddev = 0.01))
        self.bfc = tf.Variable(tf.constant(0.01, shape = [512]))

        self.wto = tf.Variable(tf.truncated_normal([512, self.actions], stddev = 0.01))
        self.bto = tf.Variable(tf.constant(0.01, shape = [self.actions]))
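
The 2304 in wfc is the size of the flattened feature map: the 84×84 input becomes 21×21 after conv1 (stride 4), 11×11 after the 2×2 max-pool, 6×6 after conv2 (stride 2), stays 6×6 after conv3 (stride 1), and 6 × 6 × 64 = 2304.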

Building the CNN model

def initConvNet(self):
        # conv1 followed by a 2x2 max-pool
        conv1 = tf.nn.relu(tf.nn.conv2d(self.input, self.w1, strides = [1, 4, 4, 1], padding = "SAME") + self.b1)
        pool = tf.nn.max_pool(conv1, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")

        conv2 = tf.nn.relu(tf.nn.conv2d(pool, self.w2, strides = [1, 2, 2, 1], padding = "SAME") + self.b2)

        conv3 = tf.nn.relu(tf.nn.conv2d(conv2, self.w3, strides = [1, 1, 1, 1], padding = "SAME") + self.b3)

        # flatten the 6x6x64 feature map and apply the fully connected layers
        conv3_reshaped = tf.reshape(conv3, [-1, 2304])

        fullyconnected = tf.nn.relu(tf.matmul(conv3_reshaped, self.wfc) + self.bfc)

        # one Q-value per action
        self.output = tf.matmul(fullyconnected, self.wto) + self.bto

Loss function and optimizer

def initNN(self):
        self.a = tf.placeholder("float", [None, self.actions])
        self.y = tf.placeholder("float", [None]) 
        out_action = tf.reduce_sum(tf.multiply(self.output, self.a), reduction_indices = 1)
        self.cost = tf.reduce_mean(tf.square(self.y - out_action))
        self.optimize = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
        
        self.saver = tf.train.Saver()
        self.session = tf.InteractiveSession()
        self.session.run(tf.global_variables_initializer())
        checkpoint = tf.train.get_checkpoint_state("saved")
        # For fresh start, comment below 2 lines
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.session, checkpoint.model_checkpoint_path)
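
Here a is a one-hot action placeholder, so reduce_sum(output * a) selects Q(s, a) for the action actually taken; the cost is the mean squared error between that value and the target y computed in train(), minimized with Adam at a learning rate of 1e-6. The final lines restore a checkpoint from the saved directory if one exists.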

Training the model

def train(self):
        # DQN
        minibatch = random.sample(self.repmem, self.batchsize)
        s_batch = [data[0] for data in minibatch]
        a_batch = [data[1] for data in minibatch]
        r_batch = [data[2] for data in minibatch]
        s_t1_batch = [data[3] for data in minibatch]

        y_batch = []
        Q_batch = self.output.eval(feed_dict={self.input : s_t1_batch})
        for i in range(0,self.batchsize):
            done = minibatch[i][4]
            if done:
                y_batch.append(r_batch[i])
            else:
                y_batch.append(r_batch[i] + self.discft * np.max(Q_batch[i]))

        self.optimize.run(feed_dict={self.y : y_batch, self.a : a_batch, self.input : s_batch})

        if self.epoch % 30000 == 0:
            self.saver.save(self.session, 'saved/' + 'snake.ckpt', global_step = self.epoch)           

Getting an action (ε-greedy)

def getAction(self):
        Q_val = self.output.eval(feed_dict={self.input : [self.s_t]})[0]
        # for print
        self.qv = Q_val
        # action array
        action = np.zeros(self.actions)
        idx = 0

        # epsilon-greedy action selection
        if random.random() <= self.epsilon:
            idx = random.randrange(self.actions)
            action[idx] = 1
        else:
            idx = np.argmax(Q_val)
            action[idx] = 1

        # anneal epsilon after the observation phase
        if self.epsilon > self.finep and self.epoch > self.observe:
            self.epsilon -= (1 - self.finep) / 500000

        return action
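
With these settings ε stays at 1 (pure exploration) throughout the 10,000-step observation phase, and is then annealed linearly by (1 − 0.001)/500,000 per step until it reaches its final value of 0.001 after roughly half a million training steps.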

Putting it together and logging the results

def run(self):
        # initialize
        # discount factor 0.99
        ag = DQN(0.99, 0, 1, 0.001, 50000, 32, 4)
        g = game.gameState()
        a_0 = np.array([1, 0, 0, 0])
        s_0, r_0, d = g.frameStep(a_0)
        s_0 = cv2.cvtColor(cv2.resize(s_0, (84, 84)), cv2.COLOR_BGR2GRAY)
        _, s_0 = cv2.threshold(s_0, 1, 255, cv2.THRESH_BINARY)
        ag.initState(s_0)
        while True:
            a = ag.getAction()
            s_t1, r, done = g.frameStep(a)
            s_t1 = self.screen_handle(s_t1)
            ts, qv = ag.addReplay(s_t1, a, r, done)
            if done:
                sc, ep = g.retScore()
                print("Epoch: {}  Q-Value: {:.3f}  Episode: {}  Score: {}".format(ts,qv,ep,sc))
            else:
                print("Epoch: {}  Q-Value: {:.3f}".format(ts,qv))
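
run() uses two DQN methods that the post does not show: initState() and addReplay(). Purely as a reading aid, here is a hypothetical sketch of what they might look like, inferred from how they are used above (four stacked frames to match the 4 input channels of w1, replay-memory eviction, training only after the observation phase, and returning the epoch and a Q-value for logging):

# Hypothetical sketches; these methods are not shown in the original post and
# are reconstructed only from how run(), train() and getAction() use them.
def initState(self, s_0):
    # stack the first preprocessed 84x84 frame four times to form the initial state
    self.s_t = np.stack((s_0, s_0, s_0, s_0), axis=2)

def addReplay(self, s_t1, action, reward, done):
    # append the new frame to the last three frames to build the next state
    s_next = np.append(self.s_t[:, :, 1:], np.reshape(s_t1, (84, 84, 1)), axis=2)
    # store the transition in the layout train() expects
    self.repmem.append((self.s_t, action, reward, s_next, done))
    if len(self.repmem) > self.REPLAYMEM:
        self.repmem.popleft()              # evict the oldest transition
    if self.epoch > self.observe:
        self.train()                       # train only after the observation phase
    self.s_t = s_next
    self.epoch += 1
    return self.epoch, np.max(self.qv)     # epoch and max Q-value for logging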

Thank you!

Changelog

2018-01-10 Created