{"id":567,"date":"2020-06-02T07:06:00","date_gmt":"2020-06-02T07:06:00","guid":{"rendered":"https:\/\/tensor.agenthub.uk\/?p=567"},"modified":"2024-05-16T07:54:48","modified_gmt":"2024-05-16T07:54:48","slug":"pytorch%e5%ae%9e%e7%8e%b0policy-gradient","status":"publish","type":"post","link":"https:\/\/tensorzen.blog\/?p=567","title":{"rendered":"Implementing Policy Gradient in PyTorch"},"content":{"rendered":"\n<p>Let us first recall a few definitions. The key idea of Policy Gradient is to update the policy by following the gradient:<\/p>\n\n\n\n<p>$$\\theta_{k+1} = \\theta_{k} + \\alpha \\nabla _{\\theta}J(\\pi_{\\theta})|_{\\theta_k}$$<\/p>\n\n\n\n<p>Here $\\pi_{\\theta}$ is a parameterized policy, $\\theta$ denotes its parameters and $\\alpha$ is the learning rate. $J(\\pi_{\\theta})$ measures the performance of the current policy $\\pi_{\\theta}$; we use its expected return $E_{\\tau \\sim \\pi_{\\theta}}[R(\\tau)]$ as that measure, where $R(\\tau)$ is the return of one episode (one round of the game) and $\\tau \\sim \\pi_{\\theta}$ means the trajectory $\\tau$ is generated by the current policy $\\pi_{\\theta}$.<\/p>\n\n\n\n<p>The gradient $\\nabla _{\\theta}J(\\pi_{\\theta})$ is given by:<\/p>\n\n\n\n<p>$$\\nabla _{\\theta}J(\\pi_{\\theta}) = E_{\\tau \\sim \\pi_{\\theta}} \\left [ \\sum_{t=0}^{T} \\nabla_{\\theta} \\log \\pi_{\\theta} (a_t | s_t)R(\\tau)\\right ]$$<\/p>\n\n\n\n<p>The previous post showed how this formula is derived; the derivation does not affect the implementation.<\/p>\n\n\n\n<p>For the environment we use the classic cart-pole task:<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"318\" 
src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2024\/05\/image-1-1024x318.png\" alt=\"\" class=\"wp-image-576\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-1-1024x318.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-1-300x93.png 300w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-1-768x239.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-1.png 1288w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>The pole on the cart pivots freely around the blue point, and the game ends once it tilts past a certain angle. We keep the pole balanced by moving the cart left or right along the horizontal track, much like an acrobat balancing a bamboo pole on a fingertip without letting it fall.<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"328\" src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2024\/05\/image-2-1024x328.png\" alt=\"\" class=\"wp-image-577\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-2-1024x328.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-2-300x96.png 300w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-2-768x246.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-2.png 1322w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>The code is based on OpenAI's Spinning Up; if you are interested, you can read the original directly: https:\/\/spinningup.openai.com\/en\/latest\/spinningup\/rl_intro3.html<\/p>\n\n\n\n<p>First, create the environment:<\/p>\n\n\n\n<div 
class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"import gym\nenv = gym.make('CartPole-v1')\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #F286C4\">import<\/span><span style=\"color: 
#F6F6F4\"> gym<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">env <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> gym.make(<\/span><span style=\"color: #DEE492\">&#39;<\/span><span style=\"color: #E7EE98\">CartPole-v1<\/span><span style=\"color: #DEE492\">&#39;<\/span><span style=\"color: #F6F6F4\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>The observation is a 4-dimensional array; what each component means does not matter here. The only available actions are moving left and moving right, so there are 2 discrete actions. The policy is a simple neural network, defined as follows:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(2 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"import torch\nimport torch.nn as nn\nclass PolicyNet(nn.Module):\n    def 
__init__(self, obs_space, act_space):\n        super().__init__()\n        self.linear1 = nn.Linear(obs_space, 32)\n        self.activate = nn.Tanh()\n        self.linear2 = nn.Linear(32, act_space)\n    \n    def forward(self, x):\n        x = self.linear1(x)\n        x = self.activate(x)\n        x = self.linear2(x)\n        return x\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #F286C4\">import<\/span><span style=\"color: #F6F6F4\"> torch<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F286C4\">import<\/span><span style=\"color: #F6F6F4\"> torch.nn <\/span><span style=\"color: #F286C4\">as<\/span><span style=\"color: #F6F6F4\"> nn<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F286C4\">class<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #97E1F1\">PolicyNet<\/span><span style=\"color: #F6F6F4\">(<\/span><span style=\"color: #97E1F1; font-style: italic\">nn<\/span><span style=\"color: #F6F6F4\">.<\/span><span style=\"color: #97E1F1; font-style: italic\">Module<\/span><span style=\"color: #F6F6F4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: 
#F286C4\">def<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">__init__<\/span><span style=\"color: #F6F6F4\">(<\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">, <\/span><span style=\"color: #FFB86C; font-style: italic\">obs_space<\/span><span style=\"color: #F6F6F4\">, <\/span><span style=\"color: #FFB86C; font-style: italic\">act_space<\/span><span style=\"color: #F6F6F4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #97E1F1; font-style: italic\">super<\/span><span style=\"color: #F6F6F4\">().<\/span><span style=\"color: #BF9EEE\">__init__<\/span><span style=\"color: #F6F6F4\">()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.linear1 <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> nn.Linear(obs_space, <\/span><span style=\"color: #BF9EEE\">32<\/span><span style=\"color: #F6F6F4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.activate <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> nn.Tanh()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.linear2 <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> nn.Linear(<\/span><span style=\"color: #BF9EEE\">32<\/span><span style=\"color: #F6F6F4\">, act_space)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #F286C4\">def<\/span><span style=\"color: #F6F6F4\"> <\/span><span 
style=\"color: #62E884\">forward<\/span><span style=\"color: #F6F6F4\">(<\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">, <\/span><span style=\"color: #FFB86C; font-style: italic\">x<\/span><span style=\"color: #F6F6F4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        x <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.linear1(x)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        x <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.activate(x)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        x <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE; font-style: italic\">self<\/span><span style=\"color: #F6F6F4\">.linear2(x)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #F286C4\">return<\/span><span style=\"color: #F6F6F4\"> x<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>The network output is not passed through a softmax, which matters later when we build the action distribution. The constructor takes two arguments: the observation dimension and the number of actions, which are 4 and 2 in this environment:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" 
style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"obs_dim = env.observation_space.shape[0]\nn_act = env.action_space.n\npolicy_net = PolicyNet(obs_dim, n_act)\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #F6F6F4\">obs_dim <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> env.observation_space.shape[<\/span><span style=\"color: #BF9EEE\">0<\/span><span style=\"color: #F6F6F4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">n_act <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> env.action_space.n<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">policy_net <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> PolicyNet(obs_dim, 
n_act)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>With policy_net we can sample an action for the current observation:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"from torch.distributions.categorical import Categorical\n\ndef get_action(obs):\n    logits_output = policy_net(obs)\n    act_distribution = Categorical(logits=logits_output)\n    return act_distribution.sample().item()\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" 
d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #F286C4\">from<\/span><span style=\"color: #F6F6F4\"> torch.distributions.categorical <\/span><span style=\"color: #F286C4\">import<\/span><span style=\"color: #F6F6F4\"> Categorical<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F286C4\">def<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #62E884\">get_action<\/span><span style=\"color: #F6F6F4\">(<\/span><span style=\"color: #FFB86C; font-style: italic\">obs<\/span><span style=\"color: #F6F6F4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    logits_output <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> policy_net(obs)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    act_distribution <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> Categorical(<\/span><span style=\"color: #FFB86C; font-style: italic\">logits<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">logits_output)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #F286C4\">return<\/span><span style=\"color: #F6F6F4\"> act_distribution.sample().item()<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>Next, let us look at how the gradient is computed:<\/p>\n\n\n\n<p>$$\\nabla _{\\theta}J(\\pi_{\\theta}) = E_{\\tau \\sim \\pi_{\\theta}} \\left [ \\sum_{t=0}^{T} \\nabla_{\\theta} \\log \\pi_{\\theta} (a_t | s_t)R(\\tau)\\right 
]$$<\/p>\n\n\n\n<p>Intuitively, if taking action $a_t$ in state $s_t$ brings a relatively high return, we should reinforce it, i.e. raise the probability of taking that action in that state, so that it is more likely to be sampled next time. How do we judge whether taking $a_t$ in $s_t$ paid off? We can use the total return of the whole episode: not ideal, but workable. For example, let the return of an episode be the number of steps from $s_0$ to the terminal state $s_T$. One episode played with $\\pi_{\\theta}$ might look like this:<\/p>\n\n\n\n<p>$$s_0, a_0^1, s_1^1,a_1^1,s_2^1,a_2^1,s_3^1,a_3^1,s_4^1,a_4^1,s_T$$<\/p>\n\n\n\n<p>This episode's return is 5; the superscript 1 marks the states and actions of the first episode. A second episode played with the same policy might look like this:<\/p>\n\n\n\n<p>$$s_0, a_0^2, s_1^2,a_1^2,s_2^2,a_2^2,s_3^2,a_3^2,s_4^2,a_4^2,s_5^2,a_5^2,s_6^2,a_6^2,s_7^2,a_7^2,s_T$$<\/p>\n\n\n\n<p>The second episode's return is 8, so it went better, and we should raise the probability of the action taken in each of its states. When computing the gradient, the second episode contributes<\/p>
\n\n\n\n<p>$$\\nabla _{\\theta} \\log \\pi_{\\theta}(a_0^2|s_0) \\times 8 + \\nabla _{\\theta} \\log \\pi_{\\theta}(a_1^2|s_1^2) \\times 8 + \\nabla _{\\theta} \\log \\pi_{\\theta}(a_2^2|s_2^2) \\times 8 + \\dots$$<\/p>\n\n\n\n<p>while the first episode contributes<\/p>\n\n\n\n<p>$$\\nabla _{\\theta} \\log \\pi_{\\theta}(a_0^1|s_0) \\times 5 + \\nabla _{\\theta} \\log \\pi_{\\theta}(a_1^1|s_1^1) \\times 5 + \\nabla _{\\theta} \\log \\pi_{\\theta}(a_2^1|s_2^1) \\times 5 + \\dots$$<\/p>\n\n\n\n<p>The terms multiplied by 8 clearly carry more weight in the gradient update, so the actions taken in the second episode are reinforced more strongly. This weighting merely works and is far from ideal; the ideal weight would be $Q_{\\pi}(s,a)$, but that is hard to compute, so we make do with the episode return for now. With that settled, let us see how to implement this logic in PyTorch.<\/p>\n\n\n\n<p>We do not use the gradient formula directly. Computation-graph frameworks such as PyTorch expect us to supply an objective function and then compute its gradient automatically, so we rewrite the gradient in objective form:<\/p>\n\n\n\n<p>$$\\tilde{J}(\\theta) = \\frac{1}{N}\\sum_{i=1}^{N} \\log \\pi_{\\theta} (a_i|s_i) R(\\tau_i)$$<\/p>\n\n\n\n<p>Here the sum runs over all $N$ sampled (state, action) pairs, and $\\tau_i$ is the episode that step $i$ belongs to. We do not care what this objective means by itself: with the sampled data held fixed, its gradient with respect to $\\theta$ is, up to a positive scale factor, a sample estimate of $\\nabla _{\\theta}J(\\pi_{\\theta})$, and that is all PyTorch needs for backpropagation.<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" 
style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"def compute_loss(obs, act, weights):\n    logits_output = policy_net(obs)\n    actions_dist = Categorical(logits=logits_output)\n    logp = actions_dist.log_prob(act)\n    return -1 * (logp * weights).mean()\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: 
#F286C4\">def<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #62E884\">compute_loss<\/span><span style=\"color: #F6F6F4\">(<\/span><span style=\"color: #FFB86C; font-style: italic\">obs<\/span><span style=\"color: #F6F6F4\">, <\/span><span style=\"color: #FFB86C; font-style: italic\">act<\/span><span style=\"color: #F6F6F4\">, <\/span><span style=\"color: #FFB86C; font-style: italic\">weights<\/span><span style=\"color: #F6F6F4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    logits_output <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> policy_net(obs)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    actions_dist <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> Categorical(<\/span><span style=\"color: #FFB86C; font-style: italic\">logits<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">logits_output)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    logp <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> actions_dist.log_prob(act)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #F286C4\">return<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #F286C4\">-<\/span><span style=\"color: #BF9EEE\">1<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #F286C4\">*<\/span><span style=\"color: #F6F6F4\"> (logp <\/span><span style=\"color: #F286C4\">*<\/span><span style=\"color: #F6F6F4\"> 
weights).mean()<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>compute_loss() takes three arguments: all observations from the many episodes played with the current $\\pi_{\\theta}$; the actions taken at those observations; and the weight of each action, i.e. the return of the episode it belongs to. Continuing the example above, the inputs are:<\/p>\n\n\n\n<p>$$obs = s_0,s_1^1,s_2^1,s_3^1,s_4^1,s_0,s_1^2,s_2^2,s_3^2,s_4^2,s_5^2,s_6^2,s_7^2$$<\/p>\n\n\n\n<p>$$act=a_0^1,a_1^1,a_2^1,a_3^1,a_4^1,a_0^2,a_1^2,a_2^2,a_3^2,a_4^2,a_5^2,a_6^2,a_7^2$$<\/p>\n\n\n\n<p>$$weights=5,5,5,5,5,8,8,8,8,8,8,8,8$$<\/p>\n\n\n\n<p>The terminal state $s_T$ is not used. Line by line:<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li>Use policy_net to get the logits of the two actions for each observation (a softmax would turn them into probabilities).<\/li>\n\n\n\n<li>Build a Categorical distribution. Pass the logits keyword explicitly, because the first positional argument of Categorical is probs, i.e. probabilities.<\/li>\n\n\n\n<li>Get the log probability of the action that was actually taken during play.<\/li>\n\n\n\n<li>Negate the mean, because PyTorch optimizers perform gradient descent (they minimize the loss), while we want to maximize the objective.<\/li>\n<\/ol>\n\n\n\n<p>What remains is to put everything in a loop that updates the policy iteratively. Roughly: with the current policy, play many episodes, capping how long we play by a number of steps, say 5000, with actions chosen by $\\pi_{\\theta}$. Record the [observation, action, weight] triples of those 5000 steps, pass them to compute_loss() to get the loss, and update the policy with a PyTorch optimizer. It is worth stressing that after every policy update we must sample a fresh 5000 steps. The number of sampled steps must not be too small, otherwise many observations and actions never get sampled. The code is below; Spinning Up already comments it in detail, and I adapted it slightly because newer gym versions changed the API.<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#f6f6f4;--cbp-line-number-width:calc(2 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#282A36\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" 
fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"# for training policy\ndef train_one_epoch():\n    # make some empty lists for logging.\n    batch_obs = []          # for observations\n    batch_acts = []         # for actions\n    batch_weights = []      # for R(tau) weighting in policy gradient\n    batch_rets = []         # for measuring episode returns\n    batch_lens = []         # for measuring episode lengths\n\n    # reset episode-specific variables\n    obs, info = env.reset()       # first obs comes from starting distribution\n    done = False            # signal from environment that episode is over\n    # holds the reward of every step; once an episode ends, their sum is the episode return (5 and 8 in the example above)\n    ep_rews = []            # list for rewards accrued throughout ep \n\n    # render first episode of each epoch\n    finished_rendering_this_epoch = False\n\n    # collect experience by acting in the environment with current policy\n    # sampling loop\n    while True:\n\n        # rendering\n        if (not finished_rendering_this_epoch) and render:\n            env.render()\n\n        # save obs\n        batch_obs.append(obs.copy())\n\n        # act in the environment\n        act = get_action(torch.as_tensor(obs, dtype=torch.float32))\n        # old gym API:\n        # obs, rew, done, _ = env.step(act)\n        # gym 0.26.2: step() returns terminated and truncated separately\n        obs, rew, terminated, truncated, info = env.step(act)\n        done = terminated or truncated\n\n        # save action, reward\n        batch_acts.append(act)\n        
ep_rews.append(rew)\n\n        if done:\n            # if episode is over, record info about episode\n            # compute the return and length of this episode\n            ep_ret, ep_len = sum(ep_rews), len(ep_rews)\n            batch_rets.append(ep_ret)\n            batch_lens.append(ep_len)\n\n            # the weight for each logprob(a|s) is R(tau)\n            batch_weights += [ep_ret] * ep_len\n\n            # reset episode-specific variables (reset() returns (obs, info) in gym 0.26+)\n            obs, info = env.reset()\n            done, ep_rews = False, []\n\n            # won't render again this epoch\n            finished_rendering_this_epoch = True\n\n            # end experience loop if we have enough of it\n            # i.e. once more than batch_size steps are sampled, leave the sampling loop and update the policy\n            if len(batch_obs) &gt; batch_size:\n                break\n\n    # take a single policy gradient update step\n    # update the policy\n    optimizer.zero_grad()\n    batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),\n                              act=torch.as_tensor(batch_acts, dtype=torch.int32),\n                              weights=torch.as_tensor(batch_weights, dtype=torch.float32)\n                              )\n    batch_loss.backward()\n    optimizer.step()\n    return batch_loss, batch_rets, batch_lens\" style=\"color:#f6f6f4;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" 
d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki dracula-soft\" style=\"background-color: #282A36\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #7B7F8B\"># for training policy<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F286C4\">def<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #62E884\">train_one_epoch<\/span><span style=\"color: #F6F6F4\">():<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># make some empty lists for logging.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_obs <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []          <\/span><span style=\"color: #7B7F8B\"># for observations<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_acts <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []         <\/span><span style=\"color: #7B7F8B\"># for actions<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_weights <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []      <\/span><span style=\"color: #7B7F8B\"># for R(tau) weighting in policy gradient<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_rets <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []         <\/span><span style=\"color: #7B7F8B\"># for measuring episode returns<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_lens <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []         <\/span><span style=\"color: #7B7F8B\"># for measuring episode lengths<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    
<\/span><span style=\"color: #7B7F8B\"># reset episode-specific variables<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    obs, info <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> env.reset()       <\/span><span style=\"color: #7B7F8B\"># first obs comes from starting distribution<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    done <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">False<\/span><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># signal from environment that episode is over<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># rewards of each step; summing them at the end of an episode gives its return (e.g. 5 and 8 in the example above)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    ep_rews <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> []            <\/span><span style=\"color: #7B7F8B\"># list for rewards accrued throughout ep <\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># render first episode of each epoch<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    finished_rendering_this_epoch <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">False<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># collect experience by acting in the environment with current policy<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># 
sampling loop<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #F286C4\">while<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">True<\/span><span style=\"color: #F6F6F4\">:<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># rendering<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #F286C4\">if<\/span><span style=\"color: #F6F6F4\"> (<\/span><span style=\"color: #F286C4\">not<\/span><span style=\"color: #F6F6F4\"> finished_rendering_this_epoch) <\/span><span style=\"color: #F286C4\">and<\/span><span style=\"color: #F6F6F4\"> render:<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            env.render()<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># save obs<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        batch_obs.append(obs.copy())<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># act in the environment<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        act <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> get_action(torch.as_tensor(obs, <\/span><span style=\"color: #FFB86C; font-style: italic\">dtype<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.float32))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># old version<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># obs, rew, done, _ = env.step(act)<\/span><\/span>\n<span class=\"line\"><span 
style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># gym 0.26.2: step() returns (obs, reward, terminated, truncated, info)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        obs, rew, terminated, truncated, info <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> env.step(act)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># the episode is over if it terminated or was truncated (time limit)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        done <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> terminated <\/span><span style=\"color: #F286C4\">or<\/span><span style=\"color: #F6F6F4\"> truncated<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #7B7F8B\"># save action, reward<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        batch_acts.append(act)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        ep_rews.append(rew)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">        <\/span><span style=\"color: #F286C4\">if<\/span><span style=\"color: #F6F6F4\"> done:<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># if episode is over, record info about episode<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># compute the return and length of this episode<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            ep_ret, ep_len <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #97E1F1\">sum<\/span><span style=\"color: #F6F6F4\">(ep_rews), <\/span><span style=\"color: #97E1F1\">len<\/span><span style=\"color: #F6F6F4\">(ep_rews)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            batch_rets.append(ep_ret)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            batch_lens.append(ep_len)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># the weight for each logprob(a|s) is 
R(tau)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            batch_weights <\/span><span style=\"color: #F286C4\">+=<\/span><span style=\"color: #F6F6F4\"> [ep_ret] <\/span><span style=\"color: #F286C4\">*<\/span><span style=\"color: #F6F6F4\"> ep_len<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># reset episode-specific variables (reset() returns (obs, info) in gym 0.26+)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            obs, info <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> env.reset()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            done, ep_rews <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">False<\/span><span style=\"color: #F6F6F4\">, []<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># won&#39;t render again this epoch<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            finished_rendering_this_epoch <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #BF9EEE\">True<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># end experience loop if we have enough of it<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #7B7F8B\"># i.e. once more than batch_size steps are sampled, leave the sampling loop and update the policy<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">            <\/span><span style=\"color: #F286C4\">if<\/span><span style=\"color: #F6F6F4\"> <\/span><span style=\"color: #97E1F1\">len<\/span><span style=\"color: #F6F6F4\">(batch_obs) <\/span><span style=\"color: #F286C4\">&gt;<\/span><span style=\"color: #F6F6F4\"> batch_size:<\/span><\/span>\n<span 
class=\"line\"><span style=\"color: #F6F6F4\">                <\/span><span style=\"color: #F286C4\">break<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># take a single policy gradient update step<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #7B7F8B\"># update the policy<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    optimizer.zero_grad()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_loss <\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\"> compute_loss(<\/span><span style=\"color: #FFB86C; font-style: italic\">obs<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.as_tensor(batch_obs, <\/span><span style=\"color: #FFB86C; font-style: italic\">dtype<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.float32),<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">                              <\/span><span style=\"color: #FFB86C; font-style: italic\">act<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.as_tensor(batch_acts, <\/span><span style=\"color: #FFB86C; font-style: italic\">dtype<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.int32),<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">                              <\/span><span style=\"color: #FFB86C; font-style: italic\">weights<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.as_tensor(batch_weights, <\/span><span style=\"color: #FFB86C; font-style: italic\">dtype<\/span><span style=\"color: #F286C4\">=<\/span><span style=\"color: #F6F6F4\">torch.float32)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">                              
)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    batch_loss.backward()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    optimizer.step()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #F6F6F4\">    <\/span><span style=\"color: #F286C4\">return<\/span><span style=\"color: #F6F6F4\"> batch_loss, batch_rets, batch_lens<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>That is the most basic (vanilla) implementation of policy gradient. The Spinning Up page linked above also describes some tricks for improving it, such as reward-to-go and baselines; take a look if you are interested.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>First, let us recall the definitions of a few variables. The key idea of Policy Gradient is to update the policy along the gradient $$\\theta_{k+1} = \\theta_{k} + a \\nabla _{\\theta}J(\\pi_{\\theta})|_{\\theta_k}$$ where $\\pi_{\\theta}$ is a parameterized policy, $\\theta$ is its parameter vector, and $J(\\pi_{\\theta})$ measures the performance of the current policy $\\pi_{\\theta}$. Here we use the expected return of $\\pi_{\\theta}$, $E_{\\tau \\sim \\pi_{\\theta}}[R(\\tau)]$, as the performance measure, where $R(\\tau)$ is the return of one episode and $\\tau \\sim \\pi_{\\theta}$ means the episode is played under the current policy $\\pi_{\\theta}$. $\\nabla _{\\theta}J(\\pi_{\\theta})$ is given by $$\\nabla _{\\theta}J(\\pi_{\\theta}) = E_{\\tau \\sim \\pi_{\\theta}} \\left [ \\sum_{t=0}^{T} \\nabla_{\\theta} 
[&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[18],"tags":[],"class_list":["post-567","post","type-post","status-publish","format-standard","hentry","category-reinforcement-learning"],"_links":{"self":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/567","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=567"}],"version-history":[{"count":13,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/567\/revisions"}],"predecessor-version":[{"id":582,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/567\/revisions\/582"}],"wp:attachment":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=567"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=567"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=567"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}