{"id":8,"date":"2022-01-29T10:23:44","date_gmt":"2022-01-29T10:23:44","guid":{"rendered":"https:\/\/tensor.agenthub.uk\/?p=8"},"modified":"2024-05-13T03:20:08","modified_gmt":"2024-05-13T03:20:08","slug":"multi-head-attention-%e8%ae%a1%e7%ae%97%e8%bf%87%e7%a8%8b","status":"publish","type":"post","link":"https:\/\/tensorzen.blog\/?p=8","title":{"rendered":"How Multi-Head Attention Is Computed"},"content":{"rendered":"\n<p>Computing attention is a bit like looking up a value in a hash table by its key. The difference is that there is no exact (hard) match. Instead, the query is compared with every key, and the information retrieved is a weighted sum of all the values, where each weight reflects how relevant the corresponding key is to the query. The formula is:<\/p>\n\n\n\n<p>$$<br>\\text {attention}(Q, K, V) = \\text {softmax} \\left ( \\frac{Q\\cdot K^T}{\\sqrt{d_{k}}} \\right ) \\cdot V<br>$$<\/p>\n\n\n\n<p>Let&#8217;s walk through the computation step by step. Take the sentence &#8220;the cat set on mat&#8221;, which has 5 words. The relevance of the word cat to each of the other words is written as $w_{cat, the},w_{cat, set},w_{cat,on},w_{cat,mat}$. Each pair gets a weight between 0 and 1, and all the weights add up to 1. If every word has its own fixed value, the information retrieved for cat is<\/p>\n\n\n\n<p>$$\\text{Info}_{cat} = w_{cat,the}v_{the} + w_{cat,set}v_{set} + w_{cat,on}v_{on} + w_{cat,mat}v_{mat}$$<\/p>\n\n\n\n<p>We think of a word&#8217;s value as the information that word carries within the current sentence, and this information is learned. A word is also related to itself, with weight $w_{cat,cat}$, so in practice the sum above also includes the term $w_{cat,cat}v_{cat}$.<\/p>\n\n\n\n<p>In practice each word is embedded as a vector. With an embedding dimension of $d_{model}=16$, &#8220;the cat set on mat&#8221; is mapped to a matrix $\\text {inputs} \\in R^{5\\times 16}$. Applying three linear transformations to this matrix gives $Q, K, V$: $Q = \\text {inputs} \\cdot W_Q^{T}$, $K = \\text {inputs} \\cdot W_K^{T}$, $V = \\text {inputs} \\cdot W_V^{T}$<\/p>\n\n\n\n<p>where $W_{Q}\\in R^{d_{model} \\times d_{model}}, W_{K}\\in R^{d_{model} \\times d_{model}}, W_{V}\\in R^{d_{model} \\times d_{model}}$. In PyTorch&#8217;s nn.MultiheadAttention, $W_Q, W_K, W_V$ are packed into a single parameter named in_proj_weight (called in_proj_weights in the code below) with shape $R^{3d_{model} \\times d_{model}}$, because one larger matrix multiplication is more efficient than three separate ones. Splitting the result afterwards gives the $Q,K,V$ we need, i.e. the query, key and value. 
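<\/p>\n\n\n\n<p>The packing described above can be checked with a small NumPy sketch. This is an illustration under stated assumptions, not PyTorch&#8217;s actual code. The input is a random $5 \\times 16$ matrix standing in for the embedded sentence, and W_Q, W_K, W_V and in_proj_weights are hypothetical random weights. The sketch stacks the three matrices row-wise in Q, K, V order (which, as far as I know, is also how PyTorch lays out in_proj_weight) and does a single matrix multiplication. This should give the same Q, K, V as three separate projections.<\/p>

```python
import numpy as np

# Assumed toy setup: 5 tokens (e.g. 'the cat set on mat'), embedding size 16.
rng = np.random.default_rng(0)
seq_len, d_model = 5, 16
inputs = rng.normal(size=(seq_len, d_model))

# Hypothetical projection matrices, each of shape (d_model, d_model).
W_Q = rng.normal(size=(d_model, d_model))
W_K = rng.normal(size=(d_model, d_model))
W_V = rng.normal(size=(d_model, d_model))

# Pack them row-wise in Q, K, V order: shape (3 * d_model, d_model).
in_proj_weights = np.concatenate([W_Q, W_K, W_V], axis=0)

# One matmul instead of three, then cut the last axis into three equal chunks.
qkv = inputs @ in_proj_weights.T       # shape (seq_len, 3 * d_model)
Q, K, V = np.split(qkv, 3, axis=-1)    # each of shape (seq_len, d_model)

print(qkv.shape, Q.shape)              # (5, 48) (5, 16)
# Same result as projecting with each matrix separately.
print(np.allclose(Q, inputs @ W_Q.T),
      np.allclose(K, inputs @ W_K.T),
      np.allclose(V, inputs @ W_V.T))  # True True True
```

<p>Because the weights are stacked along the first axis, the Q, K and V blocks sit side by side along the last axis of the product, so a single split recovers them in order.<\/p>\n\n\n\n<p>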
Each row of the three matrices Q, K and V is a linear transformation of the corresponding row of the original inputs, so each row still represents one word. Take the first row of Q, $q \\in R^{1 \\times 16}$, and compute its dot product with every row of $K \\in R^{5 \\times 16}$. The result is a new vector $a \\in R^{1 \\times 5}$ that holds the relevance of the first word to every word in the sentence, including itself. Softmax then normalizes these scores into weights between 0 and 1. Doing this for every word gives 5 such vectors, and computing all the raw scores at once is simply $Q \\cdot K^{T} = A\\in R^{5 \\times 5}$.<br>If the embedding dimension $d_{model}$ is very large, the dot product of two vectors can produce very large or very small values. These values push softmax into its saturated region, where gradients become tiny and backpropagation suffers, which is why a scaling factor $\\sqrt{d_k}$ is added. To see where this factor comes from, take two normally distributed random variables A and B, $A\\in R^{1000 \\times 64}, B\\in R^{1000 \\times 
64}$. Here &#8220;normally distributed&#8221; means that each element of matrix $A$ is drawn independently from a standard normal distribution (mean 0, variance 1). It does not mean that $A$ follows a 64-dimensional multivariate normal distribution, so if you flatten the matrix into a 64000-dimensional vector, its elements are still normally distributed. For independent random variables, $\\text {var}(A + B) = \\text {var} (A) + \\text{var} (B)$. For example:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"a = np.random.normal(size=10000)\nb = np.random.normal(size=10000)\nnp.mean(a), np.var(a)\n&gt;&gt;&gt; 0.012120657907471737, 1.0050670288852406\nnp.mean(b), np.var(b)\n&gt;&gt; -0.0022273866402510775, 1.0149182079551393\nnp.mean(a + b), np.var(a + b)\n&gt;&gt; 0.009893271267220663, 1.9960134883163734\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" 
fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #D8DEE9FF\">a <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">random<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">normal<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">size<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #B48EAD\">10000<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">b <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">random<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">normal<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">size<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #B48EAD\">10000<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">mean<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">a<\/span><span style=\"color: #ECEFF4\">),<\/span><span style=\"color: 
#D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">var<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">a<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">&gt;&gt;&gt;<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0.012120657907471737<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1.0050670288852406<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">mean<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">b<\/span><span style=\"color: #ECEFF4\">),<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">var<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">b<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">&gt;&gt;<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">0.0022273866402510775<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1.0149182079551393<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">mean<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">a <\/span><span style=\"color: #81A1C1\">+<\/span><span style=\"color: #D8DEE9FF\"> b<\/span><span style=\"color: #ECEFF4\">),<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">var<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: 
#D8DEE9FF\">a <\/span><span style=\"color: #81A1C1\">+<\/span><span style=\"color: #D8DEE9FF\"> b<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">&gt;&gt;<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0.009893271267220663<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1.9960134883163734<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>The example above shows that the variance of a sum of two independent random variables is the sum of their variances. For a matrix product, see the following:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"A = np.random.normal(size=(10000, 64))\nB = np.random.normal(size=(10000, 64))\nnp.mean(np.matmul(A, B.T)), np.var(np.matmul(A, B.T))\n&gt;&gt;&gt; 6.668697207603404e-05, 64.29069222735822\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" 
style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #D8DEE9FF\">A <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">random<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">normal<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">size<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">10000<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">64<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">B <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">random<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">normal<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">size<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">10000<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: 
#B48EAD\">64<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">mean<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">A<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> B<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)),<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">var<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">A<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> B<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">&gt;&gt;&gt;<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">6.668697207603404e-05<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">64.29069222735822<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>The variance here is 64 because each element of $A\\cdot B^T$ is the dot product of a 64-dimensional row of $A$ with a 64-dimensional row of $B$. For example, the first element is<\/p>\n\n\n\n<p>$(A\\cdot B^T)_{1,1} = a_{1,1}b_{1,1} + a_{1,2}b_{1,2} + \\cdots + a_{1,64}b_{1,64}$<\/p>
\n\n\n\n<p>Each term is a product of two independent random variables, and for independent variables with mean 0, $\\text{var}(a\\cdot b)=\\text{var}(a) \\cdot \\text{var}(b) = 1$. Treating $ab$ as a new random variable, each element is a sum of 64 independent such terms, so its variance is 64. Since $\\text{var}(X \/ c) = \\text{var}(X) \/ c^2$, dividing by the scaling factor $\\sqrt{64} = 8$ brings the variance back down to 1:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"np.var(np.matmul(A, B.T) \/ np.sqrt(64))\n&gt;&gt;&gt; 1.0045420660524722\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path 
class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">var<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">A<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> B<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">\/<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">sqrt<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">64<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">&gt;&gt;&gt;<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: 
#B48EAD\">1.0045420660524722<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>The point of this example is that the variance of the dot products grows with the embedding dimension, so extreme values become more common. The scaling factor prevents extremely large (or extremely small) scores, which would otherwise saturate softmax so that their gradients vanish and cannot propagate back. That is the rough intuition.<\/p>\n\n\n\n<p>Since the inputs go through in_proj_weights, there is naturally also an out_proj_weights. In the PyTorch implementation (the out_proj layer of nn.MultiheadAttention), the output of multi-head attention goes through one more linear transformation to produce the final output. Let&#8217;s first implement single-head self-attention in NumPy. Biases are omitted for simplicity, and softmax is taken from scipy.special.<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"import numpy as np\nfrom scipy.special import softmax\n\ndef 
attention(inputs, in_proj_weights, out_proj_weights):\n    # inputs has shape (seq_len, batch_size, embed_dim)\n    # in_proj_weights has shape (embed_dim * 3, embed_dim)\n    # out_proj_weights has shape (embed_dim, embed_dim)\n    embed_dim = inputs.shape[-1]\n    # The Q, K and V matrices are derived from the inputs through linear transformations.\n    qkv = np.matmul(inputs, in_proj_weights.T)\n    Q = qkv[:, :, :embed_dim]\n    K = qkv[:, :, embed_dim:embed_dim * 2]\n    V = qkv[:, :, embed_dim * 2:]\n\n    # To simplify the batched products, swap axes so Q, K and V have shape (batch_size, seq_len, embed_dim).\n    Q = np.swapaxes(Q, 0, 1)\n    K = np.swapaxes(K, 0, 1)\n    V = np.swapaxes(V, 0, 1)\n\n    # The attention weights are computed from the dot product of Q and K matrices.\n    atten_scores = np.matmul(Q, np.swapaxes(K, -2, -1)) # K is transposed to (batch_size, embed_dim, seq_len) for the product.\n    scaled_scores = atten_scores \/ np.sqrt(embed_dim) # scaled dot product\n    atten_weights = softmax(scaled_scores, axis=-1)\n\n    # Each query's corresponding value is calculated as the weighted sum of the V matrix, using the attention weights(atten_weights).\n    atten_output = np.matmul(atten_weights, V)\n\n    output = np.matmul(atten_output, out_proj_weights.T)\n    return np.swapaxes(output, 0, 1)\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 
2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #81A1C1\">import<\/span><span style=\"color: #D8DEE9FF\"> numpy <\/span><span style=\"color: #81A1C1\">as<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">from<\/span><span style=\"color: #D8DEE9FF\"> scipy<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">special <\/span><span style=\"color: #81A1C1\">import<\/span><span style=\"color: #D8DEE9FF\"> softmax<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">def<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">attention<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">inputs<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">in_proj_weights<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">out_proj_weights<\/span><span style=\"color: #ECEFF4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># inputs has shape (seq_len, batch_size, embed_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># in_proj_weights has shape (embed_dim * 3, embed_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># out_proj_weights has shape (embed_dim, embed_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    embed_dim <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> inputs<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">shape<\/span><span style=\"color: #ECEFF4\">[<\/span><span style=\"color: 
#81A1C1\">-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># The Q, K and V matrices are derived from the inputs through linear transformations.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    qkv <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: 
#ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> in_proj_weights<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:<\/span><span style=\"color: #D8DEE9FF\">embed_dim<\/span><span style=\"color: #ECEFF4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    K <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> embed_dim<\/span><span style=\"color: #ECEFF4\">:<\/span><span style=\"color: #D8DEE9FF\">embed_dim <\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    V <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> embed_dim <\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">:]<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># To simplify the batched products, swap axes so Q, K and V have shape (batch_size, seq_len, 
embed_dim).<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">Q<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    K <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">K<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    V <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">V<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: 
#616E88\"># The attention weights are computed from the dot product of Q and K matrices.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_scores <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">Q<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">K<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">))<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #616E88\"># K is transposed to (batch_size, embed_dim, seq_len) for the product.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    scaled_scores <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> atten_scores <\/span><span style=\"color: #81A1C1\">\/<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">sqrt<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">embed_dim<\/span><span style=\"color: #ECEFF4\">)<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #616E88\"># scaled dot product<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_weights <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span 
style=\"color: #88C0D0\">softmax<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">scaled_scores<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">axis<\/span><span style=\"color: #81A1C1\">=-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Each query&#39;s corresponding value is calculated as the weighted sum of the V matrix, using the attention weights(atten_weights).<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">atten_weights<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> V<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">atten_output<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> out_proj_weights<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #81A1C1\">return<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span 
style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">output<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>multi-head attention\u6765\u753b\u4e00\u4e0b\u53ef\u80fd\u4f1a\u66f4\u597d\u7406\u89e3\uff0c\u4e00\u4e2a\u5934\u7684:<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"490\" src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2023\/12\/image-2-1024x490.png\" alt=\"\" class=\"wp-image-43\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-2-1024x490.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-2-300x144.png 300w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-2-768x368.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-2.png 1444w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>\u4e00\u4e2a\u6837\u672c$\\text{inputs} \\in R^{6 \\times 16}$\u7ecf\u8fc7$W_q, W_k, W_v \\in R^{16 * 16}$\u7ebf\u6027\u53d8\u6362\u5230$Q, K, V \\in R^{6 \\times 16}$\uff0c\u6839\u636e\u77e9\u9635\u8fd0\u7b97\u7684\u6027\u8d28\uff0c$Q, K, V$\u7684\u6bcf\u4e00\u884c\u90fd\u662f\u539f\u8f93\u5165\u4e2d\u6bcf\u4e2a\u8bcd\u7684\u53d8\u6362\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"453\" src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2023\/12\/image-3-1024x453.png\" alt=\"\" class=\"wp-image-44\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-3-1024x453.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-3-300x133.png 300w, 
https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-3-768x340.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-3-1536x680.png 1536w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-3-2048x906.png 2048w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>The multi-head computation:<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"751\" src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2023\/12\/image-4-1024x751.png\" alt=\"\" class=\"wp-image-45\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-4-1024x751.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-4-300x220.png 300w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-4-768x563.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-4.png 1404w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>Multi-head attention differs slightly. After computing $Q, K, V$, each of them is split column-wise into several sub-matrices. The matrix-multiplication diagram above shows why this works: the first column of $Q$ is the inner product of the first column of $W_q^{T}$ with every row of the inputs. Each block of columns of $Q$ therefore depends only on the matching block of columns of $W_q^{T}$. Put plainly, the single linear projection we had before:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 
<svg">
16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"import torch\nfrom torch.nn import Linear\n\ninputs = torch.randn(6, 16)  # one sample: 6 tokens, d_model = 16\nlinear_q = Linear(16, 16)\nQ = linear_q(inputs)\nprint(Q.shape)  # torch.Size([6, 16])\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #81A1C1\">import<\/span><span style=\"color: #D8DEE9FF\"> torch<\/span><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">from<\/span><span style=\"color: #D8DEE9FF\"> torch<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">nn <\/span><span style=\"color: #81A1C1\">import<\/span><span style=\"color: #D8DEE9FF\"> Linear<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">inputs <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> torch<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">randn<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">6<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">)<\/span><span style=\"color: #D8DEE9FF\">  <\/span><span style=\"color: #616E88\"># one sample: 6 tokens, d_model = 16<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">linear_q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">Linear<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: 
#D8DEE9FF\">Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">linear_q<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #88C0D0\">print<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">Q<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">shape<\/span><span style=\"color: #ECEFF4\">)<\/span><span style=\"color: #D8DEE9FF\">  <\/span><span style=\"color: #616E88\"># torch.Size([6, 16])<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<p>It now becomes four smaller projections, one per head:<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"linear_q_h1 = Linear(16, 4)\nlinear_q_h2 = Linear(16, 4)\nlinear_q_h3 = Linear(16, 4)\nlinear_q_h4 = Linear(16, 4)\nQ1 = linear_q_h1(inputs)\nQ2 
= linear_q_h2(inputs)\nQ3 = linear_q_h3(inputs)\nQ4 = linear_q_h4(inputs)\nQ = torch.cat([Q1, Q2, Q3, Q4], dim=-1)\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #D8DEE9FF\">linear_q_h1 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">Linear<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">4<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">linear_q_h2 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">Linear<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">4<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">linear_q_h3 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span 
style=\"color: #88C0D0\">Linear<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">4<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">linear_q_h4 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">Linear<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #B48EAD\">16<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">4<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">Q1 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">linear_q_h1<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">Q2 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">linear_q_h2<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">Q3 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">linear_q_h3<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">Q4 <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">linear_q_h4<\/span><span style=\"color: #ECEFF4\">(<\/span><span 
style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> torch<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">cat<\/span><span style=\"color: #ECEFF4\">([<\/span><span style=\"color: #D8DEE9FF\">Q1<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> Q2<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> Q3<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> Q4<\/span><span style=\"color: #ECEFF4\">],<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">dim<\/span><span style=\"color: #81A1C1\">=-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span><\/code><\/pre><\/div>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"738\" src=\"https:\/\/tensor.agenthub.uk\/wp-content\/uploads\/2023\/12\/image-5-1024x738.png\" alt=\"\" class=\"wp-image-46\" srcset=\"https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-5-1024x738.png 1024w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-5-300x216.png 300w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-5-768x553.png 768w, https:\/\/tensorzen.blog\/wp-content\/uploads\/2023\/12\/image-5.png 1416w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>\u4e4b\u524d\u7684$W_q \\in R^{16 \\times 16}$\uff0c\u53d8\u62104\u4e2a$W_q\\in R^{16 \\times 4}$\uff0c\u8fd9\u6837\u5c31\u5f97\u5230\u4e864\u4e2a$Q \\in R^{6 \\times 4}$\uff0c\u518d\u8fd9\u6837\u751f\u62104\u4e2a$K \\in R^{6 \\times 4}, V \\in R^{6 \\times 4}$, \u5f97\u52304\u4e2a$A \\in R^{6 \\times 4}$\uff0c\u6700\u540e\u628a4\u4e2a$A$contact\u8d77\u6765\u4e5f\u662f$A \\in R^{6 \\times 
16}$\uff0c\u6240\u4ee5\u76f8\u6bd4\u76f4\u63a5\u5b66\u4e60$W_q \\in R^{16 \\times 16}$\u4e0d\u5982\u5b66\u4e604\u4e2a$W_q \\in R^{16 \\times 4}$\u7684\u6548\u679c\u597d\uff08\u539f\u8bba\u6587\u8bf4\u7684\uff09\u3002<\/p>\n\n\n\n<p>\u4ee3\u7801\u5982\u4e0b\uff0c\u662f\u7c7b\u4f3cpytorch\u7684\u5b9e\u73b0\u8fc7\u7a0b\uff0c\u4e2d\u95f4\u6709\u5f88\u591a\u77e9\u9635\u8f6c\u7f6e\u7684\u64cd\u4f5c\uff1a<\/p>\n\n\n\n<div class=\"wp-block-kevinbatdorf-code-block-pro\" data-code-block-pro-font-family=\"Code-Pro-JetBrains-Mono\" style=\"font-size:.875rem;font-family:Code-Pro-JetBrains-Mono,ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,monospace;line-height:1.25rem;--cbp-tab-width:2;tab-size:var(--cbp-tab-width, 2)\"><span style=\"display:block;padding:16px 0 0 16px;margin-bottom:-1px;width:100%;text-align:left;background-color:#2e3440ff\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" width=\"54\" height=\"14\" viewBox=\"0 0 54 14\"><g fill=\"none\" fill-rule=\"evenodd\" transform=\"translate(1 1)\"><circle cx=\"6\" cy=\"6\" r=\"6\" fill=\"#FF5F56\" stroke=\"#E0443E\" stroke-width=\".5\"><\/circle><circle cx=\"26\" cy=\"6\" r=\"6\" fill=\"#FFBD2E\" stroke=\"#DEA123\" stroke-width=\".5\"><\/circle><circle cx=\"46\" cy=\"6\" r=\"6\" fill=\"#27C93F\" stroke=\"#1AAB29\" stroke-width=\".5\"><\/circle><\/g><\/svg><\/span><span role=\"button\" tabindex=\"0\" data-code=\"def multihead_atten(inputs, num_heads, in_proj_weights, out_proj_weights):\n    # inputs's shape is (seq_len, batch_size, ebmed_dim)\n    (seq_len, batch_size, embed_dim) = inputs.shape\n    \n    assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'\n    head_dim = embed_dim \/\/ num_heads\n    qkv = np.matmul(inputs, in_proj_weights.T)\n    Q = qkv[:, :, :embed_dim]\n    K = qkv[:, :, embed_dim:embed_dim*2]\n    V = qkv[:, :, embed_dim*2:]    \n    \n    # Reshape Q, K, V into (seq_len, num_head * batch_size, head_dim)\n    # Partitioned the Q matrix into num_heads submatrices [Q1, Q2, 
Q3, Q4, ...] along column axis, \n    # where each Qi belongs to seq_len times head_dim.\n    # for examples, partitioned a 2x8 matrix into 2 2x4 matrices, 2 heads.\n    # [1, 2, 3, 4, 5, 6, 7 ,8]   [1, 2, 3, 4] [5, 6, 7, 8]\n    # [9, 8, 7, 6, 5, 4, 3, 2]   [9, 8, 7, 6] [5, 4, 3, 2]\n    # So matrix will be a 4x4 matrix:\n    # [1, 2, 3, 4]\n    # [5, 6, 7, 8]\n    # [9, 8, 7, 6]\n    # [5, 4, 3, 2]\n    Q = Q.reshape(seq_len, num_heads * batch_size, head_dim)\n    K = K.reshape(seq_len, num_heads * batch_size, head_dim)\n    V = V.reshape(seq_len, num_heads * batch_size, head_dim)\n\n    # Current the shape of Q, K or V is (seq_len, num_head * batch_size, head_dim)\n    # Permute the dimensions of Q, K and V into (num_head * batch_size, seq_len, head_dim)\n    Q = np.swapaxes(Q, 0, 1)\n    K = np.swapaxes(K, 0, 1)\n    V = np.swapaxes(V, 0, 1)\n\n    # attention\n    atten_scores = np.matmul(Q, np.swapaxes(K, -2, -1))\n    scaled_scores = atten_scores \/ np.sqrt(K.shape[-1])\n    atten_weights = softmax(scaled_scores, axis=-1)\n    \n    atten_output = np.matmul(atten_weights, V)\n\n    # We need to calculate linear transformation of atten_output, so permute it again.\n    atten_output = np.swapaxes(atten_output, 0, 1)\n    atten_output = atten_output.reshape(seq_len, batch_size, -1)\n    output = np.matmul(atten_output, out_proj_weights.T)\n    return output\" style=\"color:#d8dee9ff;display:none\" aria-label=\"Copy\" class=\"code-block-pro-copy-button\"><svg xmlns=\"http:\/\/www.w3.org\/2000\/svg\" style=\"width:24px;height:24px\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\" stroke-width=\"2\"><path class=\"with-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2m-6 9l2 2 4-4\"><\/path><path class=\"without-check\" stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M9 5H7a2 2 0 00-2 2v12a2 2 0 002 2h10a2 2 
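<\/path><\/svg><\/span><pre">
0 002-2V7a2 2 0 00-2-2h-2M9 5a2 2 0 002 2h2a2 2 0 002-2M9 5a2 2 0 012-2h2a2 2 0 012 2\"><\/path><\/svg><\/span><pre class=\"shiki nord\" style=\"background-color: #2e3440ff\" tabindex=\"0\"><code><span class=\"line\"><span style=\"color: #81A1C1\">import<\/span><span style=\"color: #D8DEE9FF\"> numpy <\/span><span style=\"color: #81A1C1\">as<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">def<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">softmax<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">x<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">axis<\/span><span style=\"color: #81A1C1\">=-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Numerically stable softmax: subtract the max before exponentiating.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    e <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">exp<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">x <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">max<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">x<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">axis<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\">axis<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">keepdims<\/span><span style=\"color: #81A1C1\">=True<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #81A1C1\">return<\/span><span style=\"color: #D8DEE9FF\"> e <\/span><span style=\"color: #81A1C1\">\/<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">sum<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">e<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">axis<\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\">axis<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">keepdims<\/span><span style=\"color: #81A1C1\">=True<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #81A1C1\">def<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">multihead_atten<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9\">inputs<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">num_heads<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">in_proj_weights<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">out_proj_weights<\/span><span style=\"color: #ECEFF4\">):<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># inputs has shape (seq_len, batch_size, embed_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">seq_len<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> batch_size<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> embed_dim<\/span><span style=\"color: #ECEFF4\">)<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> inputs<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">shape<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #81A1C1\">assert<\/span><span style=\"color: #D8DEE9FF\"> embed_dim <\/span><span style=\"color: #81A1C1\">%<\/span><span style=\"color: #D8DEE9FF\"> num_heads 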
<\/span><span style=\"color: #81A1C1\">==<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">&#39;<\/span><span style=\"color: #A3BE8C\">embed_dim must be divisible by num_heads<\/span><span style=\"color: #ECEFF4\">&#39;<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    head_dim <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> embed_dim <\/span><span style=\"color: #81A1C1\">\/\/<\/span><span style=\"color: #D8DEE9FF\"> num_heads<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    qkv <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">inputs<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> in_proj_weights<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:<\/span><span style=\"color: #D8DEE9FF\">embed_dim<\/span><span style=\"color: #ECEFF4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    K <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> 
embed_dim<\/span><span style=\"color: #ECEFF4\">:<\/span><span style=\"color: #D8DEE9FF\">embed_dim<\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    V <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> qkv<\/span><span style=\"color: #ECEFF4\">[:,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #ECEFF4\">:,<\/span><span style=\"color: #D8DEE9FF\"> embed_dim<\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">:]<\/span><span style=\"color: #D8DEE9FF\">    <\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Reshape Q, K, V into (seq_len, num_head * batch_size, head_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Partitioned the Q matrix into num_heads submatrices [Q1, Q2, Q3, Q4, ...] 
along column axis, <\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># where each Qi belongs to seq_len times head_dim.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># for examples, partitioned a 2x8 matrix into 2 2x4 matrices, 2 heads.<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [1, 2, 3, 4, 5, 6, 7 ,8]   [1, 2, 3, 4] [5, 6, 7, 8]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [9, 8, 7, 6, 5, 4, 3, 2]   [9, 8, 7, 6] [5, 4, 3, 2]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># So matrix will be a 4x4 matrix:<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [1, 2, 3, 4]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [5, 6, 7, 8]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [9, 8, 7, 6]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># [5, 4, 3, 2]<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> Q<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">reshape<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">seq_len<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> num_heads <\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #D8DEE9FF\"> batch_size<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> head_dim<\/span><span 
style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    K <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> K<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">reshape<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">seq_len<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> num_heads <\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #D8DEE9FF\"> batch_size<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> head_dim<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    V <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> V<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">reshape<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">seq_len<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> num_heads <\/span><span style=\"color: #81A1C1\">*<\/span><span style=\"color: #D8DEE9FF\"> batch_size<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> head_dim<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Current the shape of Q, K or V is (seq_len, num_head * batch_size, head_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># Permute the dimensions of Q, K and V into (num_head * batch_size, seq_len, head_dim)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    Q <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: 
#88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">Q<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    K <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">K<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    V <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">V<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># attention<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_scores <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span 
style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">Q<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">K<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">2<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">))<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    scaled_scores <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> atten_scores <\/span><span style=\"color: #81A1C1\">\/<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">sqrt<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">K<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">shape<\/span><span style=\"color: #ECEFF4\">[<\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">])<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_weights <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #88C0D0\">softmax<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">scaled_scores<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #D8DEE9\">axis<\/span><span style=\"color: #81A1C1\">=-<\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span 
class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">atten_weights<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> V<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #616E88\"># To apply the output linear projection, swap the axes back and merge the heads into (seq_len, batch_size, d_model).<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">swapaxes<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">atten_output<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">0<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    atten_output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> atten_output<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">reshape<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">seq_len<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> batch_size<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> <\/span><span style=\"color: #81A1C1\">-<\/span><span style=\"color: 
#B48EAD\">1<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    output <\/span><span style=\"color: #81A1C1\">=<\/span><span style=\"color: #D8DEE9FF\"> np<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #88C0D0\">matmul<\/span><span style=\"color: #ECEFF4\">(<\/span><span style=\"color: #D8DEE9FF\">atten_output<\/span><span style=\"color: #ECEFF4\">,<\/span><span style=\"color: #D8DEE9FF\"> out_proj_weights<\/span><span style=\"color: #ECEFF4\">.<\/span><span style=\"color: #D8DEE9FF\">T<\/span><span style=\"color: #ECEFF4\">)<\/span><\/span>\n<span class=\"line\"><span style=\"color: #D8DEE9FF\">    <\/span><span style=\"color: #81A1C1\">return<\/span><span style=\"color: #D8DEE9FF\"> output<\/span><\/span><\/code><\/pre><\/div>\n","protected":false},"excerpt":{"rendered":"<p>An intuitive look at how Attention and Multi-Head Attention are computed, followed by a NumPy implementation.<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[16,12,4],"tags":[],"class_list":["post-8","post","type-post","status-publish","format-standard","hentry","category-base","category-llm","category-machine-learning"],"_links":{"self":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/8","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=8"}],"version-history":[{"count":24,"href":"https:\/\/tensorzen.blog\/index.php?re
st_route=\/wp\/v2\/posts\/8\/revisions"}],"predecessor-version":[{"id":505,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=\/wp\/v2\/posts\/8\/revisions\/505"}],"wp:attachment":[{"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=8"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=8"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/tensorzen.blog\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=8"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}