{"id":665,"date":"2021-04-28T03:23:00","date_gmt":"2021-04-28T03:23:00","guid":{"rendered":"https:\/\/tensorzen.online\/?p=665"},"modified":"2024-05-21T07:53:35","modified_gmt":"2024-05-21T07:53:35","slug":"pytorch%e5%ae%9e%e7%8e%b0%e6%9c%80%e5%a4%a7%e4%bc%bc%e7%84%b6%e4%bc%b0%e8%ae%a1","status":"publish","type":"post","link":"https:\/\/tensorzen.blog\/?p=665","title":{"rendered":"PyTorch\u5b9e\u73b0\u6700\u5927\u4f3c\u7136\u4f30\u8ba1"},"content":{"rendered":"\n<p>\u6211\u5077\u5077\u751f\u6210\u4e00\u6ce2\u6570\u636e$x$\uff0c\u751f\u62101000\u4e2a\u5427\uff0c\u8fd9\u4e9b\u6570\u636e\u662f\u4ece\u4e00\u4e2a\u5206\u5e03\u53d6\u6837\u7684\uff0c\u65e2\u7136\u662f\u5077\u5077\u751f\u6210\u7684\uff0c\u90a3\u80af\u5b9a\u4e0d\u80fd\u544a\u8bc9\u8fd9\u4e2a\u5206\u5e03\u662f\u4ec0\u4e48~~~<\/p>\n\n\n\n<p>\u806a\u660e\u7684\u4f60\u80af\u5b9a\u5148\u753b\u4e2a\u6563\u70b9\u56fe\u770b\u770b\u5927\u6982\u662f\u4ec0\u4e48\u5206\u5e03\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" 
tabindex=\"0\" data-code=\"import matplotlib.pyplot as plt\nplt.hist(x, bins=20)\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #286983\">import<\/span><span style=\"color: #575279\"> matplotlib<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">pyplot <\/span><span style=\"color: #286983\">as<\/span><span style=\"color: #575279\"> plt<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">plt<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">hist<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">bins<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #D7827E\">20<\/span><span style=\"color: #797593\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n<div class=\"wp-block-image\">\n<figure class=\"aligncenter size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"375\" height=\"248\" src=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-19.png\" alt=\"\" class=\"wp-image-668\" 
srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-19.png 375w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2024\/05\/image-19-300x198.png 300w\" sizes=\"auto, (max-width: 375px) 100vw, 375px\" \/><\/figure>\n<\/div>\n\n\n<p>\u6253\u773c\u4e00\u770b\u8fd9\u4e0d\u5c31\u662f\u6b63\u6001\u5206\u5e03\u5417\uff0c\u6b63\u6001\u5206\u5e03\u7684\u8bdd\u9700\u8981\u77e5\u9053\u4e24\u4e2a\u53c2\u6570$\\mu, \\sigma$\uff0c\u4e5f\u5c31\u662f\u5747\u503c\u548c\u6807\u51c6\u5dee\uff0c\u4e8e\u662f\u806a\u660e\u7684\u4f60\u4f1a\u5b9a\u4e49\u4e24\u4e2a\u53d8\u91cf\uff0c\u8fd9\u4e24\u4e2a\u53d8\u91cf\u5c31\u662f\u4f60\u8981learning\u7684\u53c2\u6570\u3002<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"mu = torch.tensor(1.0, requires_grad = True)\nsigma = torch.tensor(1.0, requires_grad = True)\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" 
style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #575279\">mu <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">tensor<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">1.0<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">requires_grad<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">True<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">sigma <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">tensor<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">1.0<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">requires_grad<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">True<\/span><span 
style=\"color: #797593\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>\u4e3a\u4e86\u65b9\u4fbf\uff0c\u4f60\u53ef\u80fd\u4e0d\u60f3\u81ea\u5df1\u5b9e\u73b0SGD\u6765\u4f18\u5316\u8fd9\u4fe9\u53c2\u6570\u4e86\uff0c\u4f60\u4f1a\u9009\u62e9\u4f7f\u7528torch\u5185\u7f6e\u7684\u4f18\u5316\u5668\u6765\u4f18\u5316\u8fd9\u4fe9\u53c2\u6570\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"optimizer = torch.optim.SGD([mu, sigma], lr=2e-2)\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" 
stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #575279\">optimizer <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">optim<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">SGD<\/span><span style=\"color: #797593\">([<\/span><span style=\"color: #575279\">mu<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> sigma<\/span><span style=\"color: #797593\">],<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">lr<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #D7827E\">2e-2<\/span><span style=\"color: #797593\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>lr\u662f\u5b66\u4e60\u7387\uff0c\u8bbe\u5b9a\u5f97\u5c0f\u4e00\u70b9\u4f1a\u597d\u70b9\u3002PyTorch\u5185\u7f6e\u4e86\u5927\u591a\u6570\u5e38\u7528\u7684\u5206\u5e03\u51fd\u6570\uff0c\u4efb\u541b\u6311\u9009\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" 
transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"q = torch.distributions.Normal(loc=mu, scale=sigma)\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #575279\">q <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">distributions<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">Normal<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #907AA9; font-style: italic\">loc<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\">mu<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">scale<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: 
#575279\">sigma<\/span><span style=\"color: #797593\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>\u8fd9\u6837\u5c31\u76f8\u5f53\u4e8e\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5747\u503c\u4e3a1\uff0c\u6807\u51c6\u5dee\u4e3a1\u7684\u6b63\u6001\u5206\u5e03$q(x)$\uff0c\u4f60\u7684\u76ee\u6807\u662f\u901a\u8fc7\u4f7f\u7528\u6700\u5927\u4f3c\u7136\u4f30\u8ba1\u6765\u5f97\u5230$\\mu, \\sigma$\u3002\u5148\u5199\u4e2a\u4f3c\u7136\u51fd\u6570\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"negative_log_likelihood = -1 * torch.sum(q.log_prob(x_batch))\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 
002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #575279\">negative_log_likelihood <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">-<\/span><span style=\"color: #D7827E\">1<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">*<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">sum<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">q<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">log_prob<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x_batch<\/span><span style=\"color: #797593\">))<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>pytorch\u5e76\u6ca1\u6709\u63d0\u4f9bprob\u65b9\u6cd5\u8fd4\u56de\u6982\u7387\uff0c\u53cd\u6b63\u4f60\u4e5f\u4e0d\u4f1a\u7528\u6982\u7387\u76f8\u4e58\uff0c\u90fd\u662f\u6982\u7387\u7684\u5bf9\u6570\u76f8\u52a0\uff0c\u4e8e\u662f\u5bf9log_prob\u6c42\u548c\u5c31\u53ef\u4ee5\u4e86\u3002\u63a5\u4e0b\u6765\u8ba1\u7b97$\\mu, \\sigma$\u7684\u68af\u5ea6<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(1 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 
16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"negative_log_likelihood.backward()\noptimizer.step()\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #575279\">negative_log_likelihood<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">backward<\/span><span style=\"color: #797593\">()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">optimizer<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">step<\/span><span style=\"color: 
#797593\">()<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>\u987a\u4fbf\u8c03\u7528\u4e00\u4e0boptimizer.step()\u65b9\u6cd5\u6765\u66f4\u65b0\u4e00\u4e0b$\\mu, \\sigma$\uff0c\u6267\u884c\u591a\u6b21\u4e4b\u540e\uff0c\u5c31\u80fd\u5f97\u5230\u76ee\u6807\u4e86\u3002\u6ce8\u610f\u6bcf\u6b21backward()\u4e4b\u524d\u8981\u5148\u8c03\u7528optimizer.zero_grad()\u6e05\u7a7a\u68af\u5ea6\uff0c\u5426\u5219\u68af\u5ea6\u4f1a\u5728\u591a\u6b21\u8fed\u4ee3\u95f4\u7d2f\u52a0\u3002<\/p>\n\n\n\n<p>\u81f3\u4e8e\u5077\u5077\u751f\u6210\u6570\u636e\u7684\u8fc7\u7a0b\uff0c\u5176\u5b9e\u662f\u7528\u5747\u503c\u4e3a-4\uff0c\u6807\u51c6\u5dee\u4e3a2\u7684\u6b63\u6001\u5206\u5e03\u751f\u6210\u4e861000\u4e2a\u6837\u672c\u3002<\/p>\n\n\n\n<p>\u4e8e\u662f\u7ecf\u8fc7\u4e0a\u8ff0\u7684\u5b66\u4e60\u8fc7\u7a0b\uff0c$\\mu, \\sigma$\u4f1a\u5206\u522b\u5728-4\u548c2\u9644\u8fd1\u5f98\u5f8a\uff0c\u56e0\u4e3a\u6bd5\u7adf\u662f\u53d6\u6837\u7684\u6570\u636e\uff0c\u4e0e\u771f\u5b9e\u5206\u5e03\u8fd8\u662f\u6709\u5dee\u8ddd\u7684~<\/p>\n\n\n\n<p>\u5b8c\u6574\u4ee3\u7801\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro cbp-has-line-numbers\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.75rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;--cbp-line-number-color:#575279;--cbp-line-number-width:calc(2 * 0.6 * .75rem);line-height:1rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#faf4ed\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"import numpy as np\nfrom scipy.stats import norm\nimport torch\nimport matplotlib.pyplot as 
plt\n# \u751f\u6210\u6570\u636e\nx = np.random.normal(loc = -4, scale = 2, size = 1000)\nplt.hist(x, bins=20)\nx = torch.tensor(x)\n# MLE\nmu = torch.tensor(1.0, requires_grad = True)\nsigma = torch.tensor(1.0, requires_grad = True)\noptimizer = torch.optim.SGD([mu, sigma], lr=2e-2)\n\n# SGD\nidx = list(range(len(x)))\nfor epoch in range(2):\n    np.random.shuffle(idx)\n    for i in range(0,len(idx),10):\n        x_batch = x[idx[i:i+10]]\n        optimizer.zero_grad()\n        q = torch.distributions.Normal(loc=mu, scale=sigma)\n        negative_log_likelihood = -1 * torch.sum(q.log_prob(x_batch))\n        negative_log_likelihood.backward()\n        optimizer.step()\n\nprint(&quot;{},{}&quot;.format(mu.detach(), sigma.detach()))\" style=\"color:#575279;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki rose-pine-dawn\" style=\"background-color: #faf4ed\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #286983\">import<\/span><span style=\"color: #575279\"> numpy <\/span><span style=\"color: #286983\">as<\/span><span style=\"color: #575279\"> np<\/span><\/span>\n<span class=\"line\"><span style=\"color: #286983\">from<\/span><span style=\"color: #575279\"> scipy<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">stats <\/span><span style=\"color: 
#286983\">import<\/span><span style=\"color: #575279\"> norm<\/span><\/span>\n<span class=\"line\"><span style=\"color: #286983\">import<\/span><span style=\"color: #575279\"> torch<\/span><\/span>\n<span class=\"line\"><span style=\"color: #286983\">import<\/span><span style=\"color: #575279\"> matplotlib<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">pyplot <\/span><span style=\"color: #286983\">as<\/span><span style=\"color: #575279\"> plt<\/span><\/span>\n<span class=\"line\"><span style=\"color: #797593; font-style: italic\">#<\/span><span style=\"color: #9893A5; font-style: italic\"> \u751f\u6210\u6570\u636e<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">x <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> np<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">random<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">normal<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #907AA9; font-style: italic\">loc<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">-<\/span><span style=\"color: #D7827E\">4<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">scale<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">2<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">size<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">1000<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: 
#575279\">plt<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">hist<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">bins<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #D7827E\">20<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">x <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">tensor<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #797593; font-style: italic\">#<\/span><span style=\"color: #9893A5; font-style: italic\"> MLE<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">mu <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">tensor<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">1.0<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">requires_grad<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">True<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">sigma <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">tensor<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">1.0<\/span><span style=\"color: 
#797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">requires_grad<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #D7827E\">True<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">optimizer <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">optim<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">SGD<\/span><span style=\"color: #797593\">([<\/span><span style=\"color: #575279\">mu<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> sigma<\/span><span style=\"color: #797593\">],<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">lr<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #D7827E\">2e-2<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #797593; font-style: italic\">#<\/span><span style=\"color: #9893A5; font-style: italic\"> SGD<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">idx <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #56949F\">list<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #B4637A; font-style: italic\">range<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #B4637A; font-style: italic\">len<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x<\/span><span style=\"color: #797593\">)))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #286983\">for<\/span><span style=\"color: #575279\"> epoch <\/span><span style=\"color: #286983\">in<\/span><span style=\"color: #575279\"> 
<\/span><span style=\"color: #B4637A; font-style: italic\">range<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">2<\/span><span style=\"color: #797593\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">    np<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">random<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">shuffle<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">idx<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">    <\/span><span style=\"color: #286983\">for<\/span><span style=\"color: #575279\"> i <\/span><span style=\"color: #286983\">in<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #B4637A; font-style: italic\">range<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #D7827E\">0<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #B4637A; font-style: italic\">len<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">idx<\/span><span style=\"color: #797593\">),<\/span><span style=\"color: #D7827E\">10<\/span><span style=\"color: #797593\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        x_batch <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> x<\/span><span style=\"color: #797593\">[<\/span><span style=\"color: #575279\">idx<\/span><span style=\"color: #797593\">[<\/span><span style=\"color: #575279\">i<\/span><span style=\"color: #797593\">:<\/span><span style=\"color: #575279\">i<\/span><span style=\"color: #286983\">+<\/span><span style=\"color: #D7827E\">10<\/span><span style=\"color: #797593\">]]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        optimizer<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">zero_grad<\/span><span style=\"color: 
#797593\">()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        q <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">distributions<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">Normal<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #907AA9; font-style: italic\">loc<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\">mu<\/span><span style=\"color: #797593\">,<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #907AA9; font-style: italic\">scale<\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\">sigma<\/span><span style=\"color: #797593\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        negative_log_likelihood <\/span><span style=\"color: #286983\">=<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">-<\/span><span style=\"color: #D7827E\">1<\/span><span style=\"color: #575279\"> <\/span><span style=\"color: #286983\">*<\/span><span style=\"color: #575279\"> torch<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">sum<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">q<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">log_prob<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">x_batch<\/span><span style=\"color: #797593\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        negative_log_likelihood<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">backward<\/span><span style=\"color: #797593\">()<\/span><\/span>\n<span class=\"line\"><span style=\"color: #575279\">        optimizer<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">step<\/span><span style=\"color: 
#797593\">()<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #B4637A; font-style: italic\">print<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #EA9D34\">&quot;<\/span><span style=\"color: #286983\">{}<\/span><span style=\"color: #EA9D34\">,<\/span><span style=\"color: #286983\">{}<\/span><span style=\"color: #EA9D34\">&quot;<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">format<\/span><span style=\"color: #797593\">(<\/span><span style=\"color: #575279\">mu<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">detach<\/span><span style=\"color: #797593\">(),<\/span><span style=\"color: #575279\"> sigma<\/span><span style=\"color: #797593\">.<\/span><span style=\"color: #575279\">detach<\/span><span style=\"color: #797593\">()))<\/span><\/span><\/code><\/pre><\/div>\n","protected":false},"excerpt":{"rendered":"<p>\u83b7\u5f97\u4e00\u6279\u6570\u636e\u540e\u600e\u4e48\u4f7f\u7528\u6700\u5927\u4f3c\u7136\u4f30\u8ba1\u6765\u83b7\u5f97\u9884\u4f30\u7684\u53c2\u6570\u5462\uff1f\u7528PyTorch\u6765\u5b9e\u73b0\u4e0b\uff5e<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[4],"tags":[],"class_list":["post-665","post","type-post","status-publish","format-standard","hentry","category-machine-learning"],"_links":{"self":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/665","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=665"}
],"version-history":[{"count":8,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/665\/revisions"}],"predecessor-version":[{"id":675,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/665\/revisions\/675"}],"wp:attachment":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=665"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=665"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=665"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}