Sarsa 算法

1. value-based

在前面已经介绍到了,在强化学习中,学习的目标是要找到一个策略 π\pi ^\ast,使得总体回报的期望最高。这里的回报就是状态价值函数 Vπ(s)V^{\pi}\left ( s \right ) 和动作价值函数 Qπ(s,a)Q^{\pi}\left ( s,a \right )。而基于值的方法并不直接对策略 π\pi 建模,而是先学习并优化价值函数,然后基于价值函数来推导出最优策略:

π(s)=argmaxa  Q(s,a)\pi ^\ast \left ( s \right )=\underset{a}{argmax}\;Q^\ast\left ( s,a \right )

以上便是 value-based 强化学习的核心思想。

基于值的方法主要有三大类,分别为:动态规划、蒙特卡洛方法和时序差分算法。在动态规划中,需要事先知道转移矩阵 PP 和奖励函数 rr,这两个参数就构成了模型,因此动态规划也划分到基于模型(model-based)方法。

在蒙特卡洛方法中,虽不需要事先知道模型,但是更新需要得到当前的采样结束后,才能对动作价值函数 QQ 更新,更新较慢,同时,蒙特卡洛方法的高方差也导致了整体的训练会出现不平稳的现象。针对以上的问题,在本篇中介绍的基于时序差分的方法会对问题有一定的优化。

2. 时序差分

在蒙特卡洛方法中,我们对状态价值的函数的更新,采用的是增量更新的方式:

Qπ(st,at)Qπ(st,at)+1N(st,at)(GQπ(st,at))Q^{\pi }\left ( s_t,a_t \right ) \leftarrow Q^{\pi }\left ( s_t,a_t \right )+\frac{1}{N_{\left ( s_t,a_t \right )}}\left ( G- Q^{\pi }\left ( s_t,a_t \right ) \right )

其中,GG 是统计得到的回报,从状态价值函数 VV 可知:

Vπ(st)=Eτπ[Gtst=s]V^{\pi }\left ( s_t \right )=\textbf{E}_{\tau \sim \pi }\left [ G_t\mid s_t=s \right ]

GG 是通过采用得到的具体的样本,而状态价值函数 VV 则是样本的期望,同样,还有动作价值函数 QQ

Qπ(st,at)=Eτπ[Gtst=s,at=a]Q^{\pi }\left ( s_t,a_t \right )=\textbf{E}_{\tau \sim \pi}\left [ G_t\mid s_t=s,a_t=a \right ]

根据贝尔曼期望方程,有:

Qπ(s,a)=Eτπ[Rt+γQπ(s,a)st=s,at=a]Q^{\pi }\left ( s,a \right )=\textbf{E}_{\tau \sim \pi }\left [ R_t+\gamma Q^{\pi }\left ( s',a' \right )\mid s_t=s,a_t=a \right ]

改写上述的增量更新方程:

Qπ(st,at)Qπ(st,at)+α(Rt+γQπ(s,a)Qπ(st,at))Q^{\pi }\left ( s_t,a_t \right ) \leftarrow Q^{\pi }\left ( s_t,a_t \right )+\alpha \left ( R_t+\gamma Q^{\pi }\left ( s',a' \right )- Q^{\pi }\left ( s_t,a_t \right ) \right )

其中,α\alpha 称为学习率。这便是时序差分算法的更新公式,整个更新公式不再依赖于完整的采样过程,每走一步在得到价值 RtR_t 后就能更新动作价值函数。

在时序差分算法中,代表的算法有 Sarsa 算法和 Q-Learning 算法。本文聚焦在 Sarsa 算法。

2. Sarsa 算法

2.1. 算法原理

Sarsa 算法全称 state-action-reward-state-action,这也是代表了 Sarsa 算法的整个过程[1]

Sarsa 算法的更新公式为:

Qπ(st,at)Qπ(st,at)+α(Rt+γQπ(s,a)Qπ(st,at))Q^{\pi }\left ( s_t,a_t \right ) \leftarrow Q^{\pi }\left ( s_t,a_t \right )+\alpha \left ( R_t+\gamma Q^{\pi }\left ( s',a' \right )- Q^{\pi }\left ( s_t,a_t \right ) \right )

2.2. 算法实现

还是以 Frozen Lake[2] 问题展开实验,实验环境如下:

构建 Sarsa 类,并写出完整的代码,如下:

import numpy as np
import gymnasium as gym

class Sarsa:
    def __init__(self, env):
        self.env = env # 环境
        self.state_size = env.observation_space.n # 状态空间大小
        self.action_size = env.action_space.n # 动作空间大小
        self.q_table = np.random.uniform(low=0, high=0.01, size=(self.state_size, self.action_size)) # Q-table

        # INFO: 设置超参数
        self.learning_rate = 0.01
        self.discount_rate = 0.95
        self.epsilon = 1.0
        self.max_epsilon = 1.0
        self.min_epsilon = 0.01
        self.decay_rate = 0.0005

        # 训练过程
        self.episodes = 500000
        self.max_steps = 1000

    def train(self):
        for episode in range(self.episodes):
            print(f"episode: {episode}")
            state, _ = self.env.reset()
            done = False

            # INFO: 先选择初始动作 a
            if np.random.uniform(0, 1) < self.epsilon:
                action = self.env.action_space.sample()
            else:
                action = np.argmax(self.q_table[state, :])

            for step in range(self.max_steps):
                new_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated # 更新 done 标志
                # INFO: 奖励
                if done:
                    if reward == 0:  # 掉坑
                        reward = -1
                # 到终点 reward=1 保持
                else:
                    reward = -0.01

                # INFO: 选择下一步的动作
                if np.random.uniform(0, 1) < self.epsilon:
                    next_action = env.action_space.sample()
                else:
                    next_action = np.argmax(self.q_table[new_state, :])

                # INFO: SARSA 更新
                self.q_table[state, action] = self.q_table[state, action] + self.learning_rate * \
                    (reward + self.discount_rate * self.q_table[new_state, next_action] - self.q_table[state, action])

                # INFO: 更新状态和动作
                state = new_state
                action = next_action

                if done:
                    break

            # INFO: 更新 epsilon
            self.epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(-self.decay_rate * episode)
        
        return self.q_table

再写一段测试的代码,如下:

if __name__ == "__main__":
    # INFO: 创建环境
    env = gym.make('FrozenLake-v1', map_name="8x8", is_slippery=False)
    mc = Sarsa(env)
    q_table = mc.train()
    env.close()

    # INFO: 测试
    env_test = gym.make('FrozenLake-v1', map_name="8x8", is_slippery=False, render_mode='human')
    print("测试环境初始化完成")
    state, _ = env_test.reset()
    done = False
    step = 0
    while not done:
        action = np.argmax(q_table[state, :])
        state, reward, terminated, truncated, _ = env_test.step(action)
        done = terminated or truncated
        print(f"step={step}, action={action}, state={state}, reward={reward}")
        step += 1
    print("Test finished.")
    env_test.close()

3. 总结

Sarsa 算法是基于时序差分的强化学习算法,与动态规划算法相比,其不需要事先知道模型;与蒙特卡洛算法相比,Sarsa 算法可实现每一步都更新,无需像蒙特卡洛算法那样需要采样完整的 episode,才能更新。

另一点,我们看到在 Sarsa 算法中,是先采取下一步的动作后,才更新动作价值函数,这个策略也被称为 on-policy,也就是说更新的策略与实际的策略是同一个策略。

参考文献

[1] http://incompleteideas.net/book/RLbook2020.pdf

[2] https://gymnasium.farama.org/environments/toy_text/frozen_lake/