From 9fac2fca7f6cb32b294e768f5e67d7bec3b2b7d0 Mon Sep 17 00:00:00 2001 From: ZHOUhuichi <96902323+ZHOUhuichi@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:21:50 +0800 Subject: [PATCH 1/6] Add files via upload --- .../NLP\345\237\272\347\241\200.ipynb" | 687 ++++++++++++++++++ 1 file changed, 687 insertions(+) create mode 100644 "docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" diff --git "a/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" "b/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" new file mode 100644 index 000000000..ba5bcc1f9 --- /dev/null +++ "b/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" @@ -0,0 +1,687 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "119ec186", + "metadata": {}, + "source": [ + "# 词嵌入(概念部分)" + ] + }, + { + "cell_type": "markdown", + "id": "f8e5639e", + "metadata": {}, + "source": [ + "###   在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
\n", + " 计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
\n", + " 我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
\n", + " 例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
\n", + "\n", + "
一个小朋友刚学会了四个字和词-->[我] [特别] [喜欢] [学习]
\n", + " 我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
\n", + " 对于中文 我们先进行分词 我 特别 喜欢 学习
\n", + " 那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
\n", + " 现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

\n", + " 我们可以思考几个问题:
\n", + " 1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
\n", + " 2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
\n", + " 因此我们认为独热编码不是一个很好的词嵌入方法。
\n", + "\n", + " 我们再来介绍一下 **稠密表示**
\n", + "\n", + " 稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97]" + ] + }, + { + "cell_type": "markdown", + "id": "4db86da3", + "metadata": {}, + "source": [ + "# Bag of Words词袋表示" + ] + }, + { + "cell_type": "markdown", + "id": "44dc9252", + "metadata": {}, + "source": [ + "  词袋表示顾名思义,我们往一个袋子中装入我们的词汇,构成一个词袋,当我们想表达的时候,我们将其取出,构建词袋的方法可以有如下形式。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "823f8f2d", + "metadata": {}, + "outputs": [], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]#我们的语料库\n", + "word_list = ' '.join(corpus).split()\n", + "word_list = list(sorted(set(word_list)))\n", + "word_dict = {w: i for i, w in enumerate(word_list)}\n", + "number_dict = {i: w for i, w in enumerate(word_list)}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8eaeb37d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'do': 0,\n", + " 'drinking': 1,\n", + " 'hate': 2,\n", + " 'i': 3,\n", + " 'like': 4,\n", + " 'love': 5,\n", + " 'nlp': 6,\n", + " 'playing': 7,\n", + " 'reading': 8}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2bf380c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: 'do',\n", + " 1: 'drinking',\n", + " 2: 'hate',\n", + " 3: 'i',\n", + " 4: 'like',\n", + " 5: 'love',\n", + " 6: 'nlp',\n", + " 7: 'playing',\n", + " 8: 'reading'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "number_dict" + ] + }, + { + "cell_type": "markdown", + "id": "90e0ef43", + "metadata": {}, + "source": [ + " 根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9821ed2a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "voc_size = len(word_dict)\n", + "bow = []\n", + "for i,name in enumerate(word_dict):\n", + " bow.append(np.eye(voc_size)[word_dict[name]])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "03f1f12f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 1., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 1., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 1., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 1., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 1., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 1., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 1., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 0., 1.])]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bow" + ] + }, + { + "cell_type": "markdown", + "id": "086a5fd2", + "metadata": {}, + "source": [ + "# N-gram:基于统计的语言模型\n", + " N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。\n", + "\n", + " N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 \n", + "\n", + "**N-gram 模型的优点:**\n", + "\n", + "简单易用,N-gram 模型的概念非常简单,实现起来也很容易。 \n", + "能够捕捉到语言中的相关性,N-gram 模型通过考虑连续出现的 n 个词语来预测下一个词语,因此它能够捕捉到语言中词语之间的相关性。 \n", + "可以使用已有的语料库进行训练,N-gram 模型可以使用已有的大量语料库进行训练,例如 Google 的 N-gram 数据库,这样可以大大提高模型的准确性。 \n", + "\n", + "**N-gram 模型的缺点:**\n", + "\n", + "对于短文本数据集不适用,N-gram 模型需要大量的文本数据进行训练,因此对于短文本数据集可能无法达到较高的准确性。 \n", + "容易受到噪声和语言纠错的影响,N-gram 模型是基于语料库进行训练的,如果语料库中包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。 \n", + "无法捕捉到语言中的非线性关系,N-gram 模型假设语言中的关系是线性的,但事实上语言中可能存在复杂的非线性关系,N-gram 模型无法捕捉到这些关系。" + ] + }, + { + "cell_type": "markdown", + "id": "1f5ad65b", + "metadata": {}, + "source": [ + "# NNLM:前馈神经网络语言模型\n", + " 下面通过前馈神经网络模型来**展示滑动**窗口的使用" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7bddfa77", + "metadata": {}, + "outputs": [], + "source": [ + "#导入必要的库\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from tqdm import tqdm\n", + "from torch.autograd import Variable\n", + "dtype = torch.FloatTensor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "29f23588", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i',\n", + " 'like',\n", + " 'reading',\n", + " 'i',\n", + " 'love',\n", + " 'drinking',\n", + " 'i',\n", + " 'hate',\n", + " 'playing',\n", + " 'i',\n", + " 'do',\n", + " 'nlp']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split()\n", + "word_list" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "12b58886", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 1000 cost = 1.010682\n", + "epoch: 2000 cost = 0.695155\n", + "epoch: 3000 cost = 0.597085\n", + "epoch: 4000 cost = 0.531892\n", + "epoch: 5000 cost = 0.376044\n", + "epoch: 6000 cost = 0.118038\n", + "epoch: 7000 cost = 0.077081\n", + "epoch: 8000 cost = 0.053636\n", + "epoch: 9000 cost = 0.038089\n", + "epoch: 10000 cost = 0.027224\n", + "[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'do']] -> ['studying', 'datawhale', 'playing', 'nlp']\n" + ] + } + ], + "source": [ + "#构建我们需要的语料库\n", + "corpus = [\"i like studying\", \"i love datawhale\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split() #将语料库转化为一个个单词 ,如['i', 'like', 'reading', 'i', ...,'nlp']\n", + "word_list = list(sorted(set(word_list))) #用set去重后转化为链表\n", + "# print(word_list)\n", + "\n", + "word_dict = {w: i for i, w in enumerate(word_list)} #将词表转化为字典 这边是词对应到index\n", + "number_dict = {i: w for i, w in enumerate(word_list)}#这边是index对应到词\n", + "# print(word_dict)\n", + "# print(number_dict)\n", + "\n", + "n_class = len(word_dict) #计算出我们词表的大小,用于后面词向量的构建\n", + "\n", + "m = 2 #词嵌入维度\n", + "n_step = 2 #滑动窗口的大小\n", + "n_hidden = 2 #隐藏层的维度为2\n", + "\n", + "\n", + "def make_batch(sentence): #由于语料库较小,我们象征性将训练集按照批次处理 \n", + " input_batch = []\n", + " target_batch = []\n", + "\n", + " for sen in sentence:\n", + " word = sen.split()\n", + " input = [word_dict[n] for n in word[:-1]]\n", + " target = word_dict[word[-1]]\n", + "\n", + " input_batch.append(input)\n", + " target_batch.append(target)\n", + "\n", + " return input_batch, target_batch\n", + "\n", + "\n", + "class NNLM(nn.Module): #搭建一个NNLM语言模型\n", + " def __init__(self):\n", + " super(NNLM, self).__init__()\n", + " self.embed = nn.Embedding(n_class, m)\n", + " self.W = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n", + " self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n", + "\n", + " self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n", + " self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n", + "\n", + " def forward(self, x):\n", + " x = self.embed(x) # 4 x 2 x 2\n", + " x = x.view(-1, n_step * m)\n", + " tanh = torch.tanh(self.d + torch.mm(x, self.W)) # 4 x 2\n", + " output = self.b + torch.mm(tanh, self.U)\n", + " return output\n", + "\n", + "model = NNLM()\n", + "\n", + "criterion = nn.CrossEntropyLoss() #损失函数的设置\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001) #优化器的设置\n", + "\n", + "input_batch, target_batch = make_batch(corpus) #训练集和标签值\n", + "input_batch = Variable(torch.LongTensor(input_batch))\n", + "target_batch = Variable(torch.LongTensor(target_batch))\n", + "\n", + "for epoch in range(10000): #训练过程\n", + " optimizer.zero_grad()\n", + "\n", + " output = model(input_batch) # input: 4 x 2\n", + "\n", + " loss = criterion(output, target_batch)\n", + "\n", + " if (epoch + 1) % 1000 == 0:\n", + " print('epoch:', '%04d' % (epoch + 1), 'cost = {:.6f}'.format(loss.item()))\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "predict = model(input_batch).data.max(1, keepdim=True)[1]#模型预测过程\n", + "\n", + "print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in predict.squeeze()])" + ] + }, + { + "cell_type": "markdown", + "id": "93d8cd2f", + "metadata": {}, + "source": [ + "# Word2Vec模型:主要采用Skip-gram和Cbow两种模式\n", + " 前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。\n", + " skip-gram模型(跳字模型)是用中心词去预测周围词\n", + " cbow模型(连续词袋模型)是用周围词预测中心词" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "066f68a0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 10000 cost = 1.955088\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 20000 cost = 1.673096\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 30000 cost = 2.247422\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 40000 cost = 2.289902\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 50000 cost = 2.396217\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 60000 cost = 1.539688\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 70000 cost = 1.638879\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 80000 cost = 2.279797\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 90000 cost = 1.992100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 100000 cost = 1.307715\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD6CAYAAACiefy7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnsUlEQVR4nO3de3hU1b3/8fc34U64KagRqQGLyC2BECSAXGysYKmAF4poRVFLU8VDbbViFZtqPW0PtB6pYkRBQFGOCoIoFn8IFDQIBAgIyr1RrhLBhATDJWT9/pghTUICCZnMTGY+r+fJk9lrr9nrO0PyYWftPXubcw4REQl9EYEuQERE/EOBLyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+FKtzCzGzDZVov9QM+tQnTWJhCsL5vPwmzdv7mJiYgJdhlTB8ePH2bFjBx07dqxQ/8zMTJo0aUKzZs2quTKR0LR27dpvnXMtylpXy9/FVEZMTAzp6emBLkOqIDMzkxtuuIGuXbuSlpZGy5YtmT9/Pq+//jpTpkzhxIkT/PCHP+S1114jIyODn/70pxQWFlJQUMCcOXMAeOCBB8jKyqJBgwa8/PLLXHXVVQF+VSLBy8y+Km+dpnSk2m3fvp0HHniAzZs307RpU+bMmcPNN9/MmjVr2LBhA+3bt2fq1Kn06tWLwYMHM2HCBDIyMrjiiisYPXo0//jHP1i7di0TJ07k/vvvD/TLEamxgnoPX0JD69at6dKlCwDdunUjMzOTTZs28cQTT5CdnU1eXh4DBgw443l5eXmkpaUxbNiworbjx4/7q2yRkKPAl2pXt27doseRkZHk5+dz9913M2/ePOLi4pg+fTrLli0743mFhYU0bdqUjIwM/xUrEsI0pSMBkZubS3R0NCdPnmTWrFlF7Y0aNSI3NxeAxo0b07p1a95++20AnHNs2LAhIPWKhAIFvgTE008/TY8ePfjxj39c4iDsbbfdxoQJE+jatSs7d+5k1qxZTJ06lbi4ODp27Mj8+fMDWLVIzRbUp2UmJCQ4naUTvuat38uERVvZl53PpU3r88iAdgzt2jLQZYkENTNb65xLKGud5vAlKM1bv5fH5n5O/slTAOzNzuexuZ8DKPRFzpOmdCQoTVi0tSjsT8s/eYoJi7YGqCKRmk+BL0FpX3Z+pdpF5NwU+BKULm1av1LtInJuCnwJSo8MaEf92pEl2urXjuSRAe0CVJFIzaeDthKUTh+Y1Vk6Ir6jwJegNbRrSwW8iA/5ZErHzKaZ2cHyrntuHpPMbIeZbTSzeF+MKyIiFeerOfzpwMCzrL8BaOv9Gg286KNxRUSkgnwS+M655cDhs3QZAsx0Hp8BTc0s2hdji4hIxfjrLJ2WwO5iy3u8bWcws9Fmlm5m6VlZWX4pTkQkHPgr8K2MtjIv4uOcm+KcS3DOJbRoUeZdukRE5Dz4K/D3AK2KLV8G7PPT2CIigv8C/z1gpPdsnUQgxzm3309ji4gIPjoP38zeBPoDzc1sD/AHoDaAcy4VWAj8BNgBfA+M8sW4IiJScT4JfOfciHOsd8ADvhhLRETOj66lIyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+BKSMjMz6dSp03k//+677+add97xYUUigafAFymloKAg0CWIVAsFvoSsgoIC7rrrLmJjY7n11lv5/vvveeqpp+jevTudOnVi9OjReD4iAv379+f3v/89/fr147nnniuxnfHjx3P33XdTWFgYiJch4jMKfAlZW7duZfTo0WzcuJHGjRszefJkxowZw5o1a9i0aRP5+fm8//77Rf2zs7P517/+xW9/+9uitt/97nccPHiQV199lYgI/bpIzaafYAlZrVq1onfv3gD8/Oc/55NPPmHp0qX06NGDzp07s2TJEjZv3lzUf/jw4SWe//TTT5Odnc1LL72EWVkXfBWpWXRPWwlZpUPazLj//vtJT0+nVatWpKSkcOzYsaL1DRs2LNG/e/furF27lsOHD3PBBRf4pWaR6qQ9fAlZX3/9NStXrgTgzTff5JprrgGgefPm5OXlnfMsnIEDBzJu3DgGDRpEbm5utdcrUt20hy8hq3379syYMYNf/vKXtG3bll/96ld89913dO7cmZiYGLp3737ObQwbNozc3FwGDx7MwoULqV+/vh8qF6kedvoshWCUkJDg0tPTA12GhJkPdn3Ac+ue48DRA1zS8BLGxo9lUJtBgS5LpELMbK1zLqGsddrDFynmg10fkJKWwrFTnrn9/Uf3k5KWAqDQlxpPc/gixTy37rmisD/t2KljPLfuuXKeIVJzKPBFijlw9ECl2kVqEgW+SDGXNLykUu0iNYkCX6SYsfFjqRdZr0Rbvch6jI0fG6CKRHxHgS9nlZqaysyZM32yrZiYGL799lufbKu6DGoziJReKUQ3jMYwohtGk9IrRQdsJSToLB05q+Tk5ECX4HeD2gxSwEtI8skevpkNNLOtZrbDzMaVsb6JmS0wsw1mttnMRvliXDk/Q4cOpVu3bnTs2JEpU6YAEBUVxeOPP05cXByJiYl88803AKSkpDBx4kTAc0XJhx56iL59+9K+fXvWrFnDzTffTNu2bXniiSfOun0RCbwqB76ZRQIvADcAHYARZtahVLcHgC+cc3FAf+BvZlanqmPL+Zk2bRpr164lPT2dSZMmcejQIY4ePUpiYiIbNmygb9++vPzyy2U+t06dOixfvpzk5GSGDBnCCy+8wKZNm5g+fTqHDh0qd/siEni+2MO/GtjhnNvlnDsBzAaGlOrjgEbmuZpVFHAY0F0mAmTSpElFe/K7d+9m+/bt1KlTh5/+9KcAdOvWjczMzDKfO3jwYAA6d+5Mx44diY6Opm7durRp04bdu3eXu/2aqlevXoEuQcRnfDGH3xLYXWx5D9CjVJ/ngfeAfUAjYLhzrsy7SZjZaGA0wA9+8AMflCfFLVu2jMWLF7Ny5UoaNGhA//79OXbsGLVr1y66umRkZGS5d32qW7cuABEREUWPTy8XFBSUu/2aKi0tLdAliPiML/bwy7pQeOkL9AwAMoBLgS7A82bWuKyNOeemOOcSnHMJLVq08EF5UlxOTg7NmjWjQYMGbNmyhc8++6xGbd/foqKiAl1CtSp+jEZCny8Cfw/QqtjyZXj25IsbBcx1HjuAfwNX+WBsqaSBAwdSUFBAbGws48ePJzExsUZtX0TOX5WvlmlmtYBtQBKwF1gD3O6c21ysz4vAN865FDO7GFgHxDnnznpStq6WWXMdXX+QI4syOZV9nMimdWk8IIaGXS8KdFmVFhUVRV5eXqDL8KlnnnmGmTNn0qpVK1q0aEG3bt247rrrSE5O5vvvv+eKK65g2rRpNGvWjDVr1nDvvffSsGFDrrnmGj788EM2bdoU6JcgZ3G2q2VWeQ/fOVcAjAEWAV8CbznnNptZspmdPon7aaCXmX0OfAw8eq6wl5rr6PqDZM/dzqns4wCcyj5O9tztHF1/MMCVydq1a5k9ezbr169n7ty5rFmzBoCRI0fy17/+lY0bN9K5c2f++Mc/AjBq1ChSU1NZuXIlkZGRgSxdfMAn5+E75xY65650zl3hnHvG25bqnEv1Pt7nnLveOdfZOdfJOfe6L8YV/5g+fTpjxoypcP8jizJxJ0sek3cnCzmyKNPHlUllrVixgptuuokGDRrQuHFjBg8ezNGjR8nOzqZfv34A3HXXXSxfvpzs7Gxyc3OLzlS6/fbbA1m6+IAurSA+d3rPvqLt4l8VvSF7MN8cSc6PAj/Elfep2t/+9rfEx8eTlJREVlYW4Pkk7a9//Wt69epFp06dWL169Rnby8rK4pZbbqF79+50796dTz/99Iw+kU3rntF2tvagsfEteLYTpDT1fN/4VsjN3/ft25d3332X/Px8cnNzWbBgAQ0bNqRZs2asWLECgNdee41+/frRrFkzGjVqVHSm1ezZswNZuviAAj/Elfep2vj4eNatW0e/fv2K5msBjh49SlpaGpMnT+aee+45Y3tjx47loYceYs2aNcyZM4f77rvvjD6NB8RgtUv+aFntCBoPiPH56/OZjW/Bgv+CnN2A83xf8F+e9hASHx/P8OHD6dKlC7fccgt9+vQBYMaMGTzyyCPExsaSkZHBk08+CcDUqVMZPXo0PXv2xDlHkyZNAlm+VJEunhbiJk2axLvvvgtQ9KnXiIgIhg8fDsDPf/5zbr755qL+I0aMADx7gkeOHCE7O7vE9hYvXswXX3xRtHzkyBFyc3Np1KhRUdvps3Fq1Fk6Hz8FJ/NLtp3M97TH/iwwNVWTxx9/nMcff/yM9tKfmdi4cSNLly7l5ptvpkmTJmRmZpKQUObJH1JDKPBDWEU/9Vp8Trf0/G7p5cLCQlauXEn9+vXPOnbDrhcFd8CXlrOncu0hbuPGjSxYsID169fzySefUFhYSLNmzXjppZcCXZpUgaZ0Qlh5n3otLCzknXfeAeCNN97gmmuuKXrO//3f/wHwySef0KRJkzP+hL/++ut5/vnni5YzMjKq+VX4SZPLKtce4j7++GNOnjxJp06dSE5O5v7772fEiBGsX78+0KVJFSjwQ1h5n3pt2LAhmzdvplu3bixZsqRovhagWbNm9OrVi+TkZKZOnXrGNidNmkR6ejqxsbF06NCB1NRUv72eapX0JNQu9VdL7fqe9jATFRVFTk4Oubm5vPWW5xhGRkYGCxcuJCcnJ8DVSVVU+ZO21UmftK0e5X16tH///kycOLHcedovVyxlxeyZ5B76lkYXNqfPbSNp3+fa6i7Xfza+5Zmzz9nj2bNPejLk5u8rIioqiqeffrpEuGdkZLBv3z5GjBjBQw89FMDq5Fyq9ZO2Eh6+XLGUj6Y8T+63WeAcud9m8dGU5/lyxdJAl+Y7sT+DhzZBSrbnexiG/WlJSUnk5eUxefLkoraIiAiSkpL44IMP6NmzJ99++y0fffQRPXv2JD4+nmHDhoXcaayhRoEfhsr7pVy2bFm5e/crZs+k4ETJD04VnDjOitm+ud+tBJfY2FiSkpKKLqdQv359YmJi2LlzJ3/5y19YuHAhAH/6059YvHgx69atIyEhgb///e+BLFvOQWfpSIXkHir70kfltUvN1759ey688EJSUlKYPn06EyZMIDMzk48++ojGjRvz/vvv88UXX9C7d28ATpw4Qc+ePQNctZyNAl8qpNGFzT3TOWW0S3ho06YNu3btYtu2bSQkJOCc48c//jFvvvlmoEuTCtKUjlRIn9tGUqtOyUsj1KpTlz63jQxQReJvl19+OXPnzmXkyJFs3ryZxMREPv30U3bs2AHA999/z7Zt2wJcpZyNAl8qpH2fa7l+9BgaNW8BZjRq3oLrR48JrbN05JzatWvHrFmzGDZsGEeOHGH69OmMGDGC2NhYEhMT2bJlS6BLlLPQaZkiUmn7D8xn186JHDu+n3p1o2lzxcNEXzIk0GUJZz8tU3P4IlIp+w/MZ8uWxyks9Fx76NjxfWzZ4rk2j0I/uGlKR0QqZdfOiUVhf1phYT67dupm6MFOgS8ilXLs+P5KtUvwUOCLSKXUqxtdqfbzkZmZyRtvvOGz7YmHAl9EKqXNFQ8TEVHyQnMREfVpc8XDPhtDgV89FPgiUinRlwzhqqueoV7dSwHju8MX8Iv7shn/xAI6derEHXfcweLFi+nduzdt27Zl9erVHD16lHvuuYfu3bvTtWtX5s+fD3iCvU+fPsTHxxMfH09aWhoA48aNY8WKFXTp0oVnn302gK82xDjnqvwFDAS2AjuAceX06Q9kAJuBf1Vku926dXMiEtz+/e9/u8jISLdx40Z36tQpFx8f70aNGuUKCwvdvHnz3JAhQ9xjjz3mXnvtNeecc999951r27aty8vLc0ePHnX5+fnOOee2bdvmTv/OL1261A0aNChgr6kmA9JdOZla5dMyzSwSeAH4MbAHWGNm7znnvijWpykwGRjonPvazGrQrZBE5Fxat25N586dAejYsSNJSUmYGZ07dyYzM5M9e/bw3nvvMXGi50yeY8eO8fXXX3PppZcyZswYMjIyiIyM1Cd1q5kvzsO/GtjhnNsFYGazgSHAF8X63A7Mdc59DeCcO+iDcUUkSNSt+5/LbkRERBQtR0REUFBQQGRkJHPmzKFdu3YlnpeSksLFF1/Mhg0bKCwspF69en6tO9z4Yg6/JbC72PIeb1txVwLNzGyZma01s3IvwGJmo80s3czSs7LOvFiXiNQ8AwYM4B//+Mfp6d2iWyXm5OQQHR1NREQEr732GqdOnQKgUaNG5ObmBqxeX5g+fTpjxowJdBkl+CLwrYy20tdrqAV0AwYBA4DxZnZlWRtzzk1xziU45xJatGjhg/JEJNDGjx/PyZMniY2NpVOnTowfPx6A+++/nxkzZpCYmMi2bdto2LAh4Lkef61atYiLiwvYQVvnHIWFhQEZu7pU+Vo6ZtYTSHHODfAuPwbgnPtzsT7jgHrOuRTv8lTgn865t8+27bKupTNp0iRefPFFDhw4wKOPPsq4cePKfO706dNJT08vccNtEQle89bvZcKirezLzufSpvV5ZEA7hnYtPVlQvTIzM7nhhhu49tprWblyJUOHDuX999/n+PHj3HTTTfzxj38EYOjQoezevZtjx44xduxYRo8eDcCrr77Kn//8Z6Kjo7nyyiupW7eu3zOouq+lswZoa2atgb3AbXjm7IubDzxvZrWAOkAP4Lz+2548eTIffvghrVu3rkLJIhJM5q3fy2NzPyf/pGdKZ292Po/N/RzA76G/detWXn31VYYOHco777zD6tWrcc4xePBgli9fTt++fZk2bRoXXHAB+fn5dO/enVtuuYUTJ07whz/8gbVr19KkSROuvfZaunbt6tfaz6XKUzrOuQJgDLAI+BJ4yzm32cySzSzZ2+dL4J/ARmA18IpzblNlx0pOTmbXrl0MHjyYZ599tmh+7O2336ZTp07ExcXRt2/fov779u1j4MCBtG3blt/97ndVfakiUk0mLNpaFPan5Z88xYRFW/1ey+WXX05iYiIfffQRH330EV27diU+Pp4tW7awfft2wDPTEBcXR2JiIrt372b79u2sWrWK/v3706JFC+rUqcPw4cP9Xvu5+ORqmc65hcDCUm2ppZYnABOqMk5qair//Oc/Wbp0Ke+//35R+1NPPcWiRYto2bIl2dnZRe0ZGRmsX7+eunXr0q5dOx588EFatWpVlRJEpBrsy86vVHt1On0cwTnHY489xi9/+csS65ctW8bixYtZuXIlDRo0oH///hw7dgwAs7IOaQaPkPikbe/evbn77rt5+eWXi47yAyQlJdGkSRPq1atHhw4d+OqrrwJYpYiU59Km9SvV7g8DBgxg2rRp5OXlAbB3714OHjxITk4OzZo1o0GDBmzZsoXPPvsMgB49erBs2TIOHTrEyZMnefvtsx6iDIiQuB5+amoqq1at4oMPPqBLly5kZGQAJc8NjoyMpKCgIEAVisjZPDKgXYk5fID6tSN5ZEC7szyrel1//fV8+eWXRTdmj4qK4vXXX2fgwIGkpqYSGxtLu3btSExMBCA6OpqUlBR69uxJdHQ08fHxJXZAg0FIBP7OnTvp0aMHPXr0YMGCBezevfvcTxKRoHH6wGygz9KJiYlh06b/HF4cO3YsY8eOPaPfhx9+eEZbzoIFXPPa68yPrEWtyFpcdN11NLnxxmqtt7JCIvAfeeQRtm/fjnOOpKQk4uLiivbyRaRmGNq1pd8D3ldyFixg//gncd65/IJ9+9g//kmAoAp93dNWRKSKtv8oiYJ9+85or3XppbRd8rFfawm7e9rOOXCYP+/az97jJ2lZtzaPtYnmlksuCHRZIhKiCvaXfbev8toDJSTO0iluzoHDPLx1N3uOn8QBe46f5OGtu5lz4HCgSxOREFUruuy7fZXXHighF/h/3rWf/MKS01T5hY4/7wqu/2lFJHRc9NCvsVJX+rR69bjooV8HpqByhNyUzt7jJyvVLiJSVacPzB589n8p2L+fWtHRXPTQr4PqgC2EYOC3rFubPWWEe8u6tQNQjYiEiyY33hh0AV9ayE3pPNYmmvoRJT/eXD/CeKxNcM2liYj4W8jt4Z8+G0dn6YiIlBRygQ+e0FfAi4iUFHJTOiIiUjYFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGzcWfp1N7NTZnarL8YVEZGKq3Lgm1kk8AJwA9ABGGFmHcrp91dgUVXHFBGRyvPFHv7VwA7n3C7n3AlgNjCkjH4PAnOAgz4YU0REKskXgd8S2F1seY+3rYiZtQRuAlLPtTEzG21m6WaWnpWV5YPyREQEfBP4VkZb6Tuj/y/wqHPu1Lk25pyb4pxLcM4ltGjRwgfliYgI+OZqmXuAVsWWLwNK3749AZhtZgDNgZ+YWYFzbp4PxpcaIDU1lQYNGjBy5MhAlyIStnwR+GuAtmbWGtgL3AbcXryDc6716cdmNh14X2EfXpKTkwNdgkjYq/KUjnOuABiD5+ybL4G3nHObzSzZzPRbXgNlZmZy1VVXcd9999GpUyfuuOMOFi9eTO/evWnbti2rV6/m8OHDDB06lNjYWBITE9m4cSOFhYXExMSQnZ1dtK0f/vCHfPPNN6SkpDBx4kQAdu7cycCBA+nWrRt9+vRhy5YtAXqlIuHFJzdAcc4tBBaWaivzAK1z7m5fjCnVa8eOHbz99ttMmTKF7t2788Ybb/DJJ5/w3nvv8d///d+0atWKrl27Mm/ePJYsWcLIkSPJyMhgyJAhvPvuu4waNYpVq1YRExPDxRdfXGLbo0ePJjU1lbZt27Jq1Sruv/9+lixZEqBXKhI+QvKOV1J1rVu3pnPnzgB07NiRpKQkzIzOnTuTmZnJV199xZw5cwD40Y9+xKFDh8jJyWH48OE89dRTjBo1itmzZzN8+PAS283LyyMtLY1hw4YVtR0/ftx/L0wkjCnwpUx169YtehwREVG0HBERQUFBAbVqnfmjY2b07NmTHTt2kJWVxbx583jiiSdK9CksLKRp06ZkZGRUa/0iciZdS0fOS9++fZk1axYAy5Yto3nz5jRu3Bgz46abbuI3v/kN7du358ILLyzxvMaNG9O6dWvefvttAJxzbNiwwe/1i4QjBb6cl5SUFNLT04mNjWXcuHHMmDGjaN3w4cN5/fXXz5jOOW3WrFlMnTqVuLg4OnbsyPz58/1VtkhYM+dKf0YqeCQkJLj09PRAlyE+9MGuD3hu3XMcOHqASxpewtj4sQxqMyjQZYmEDDNb65xLKGud5vDFbz7Y9QEpaSkcO3UMgP1H95OSlgKg0BfxA03piN88t+65orA/7dipYzy37rkAVSQSXhT44jcHjh6oVLuI+JYCX/zmkoaXVKpdRHxLgS9+MzZ+LPUi65VoqxdZj7HxYwNUkUh40UFb8ZvTB2Z1lo5IYCjwxa8GtRmkgBcJEE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJnwS+mQ00s61mtsPMxpWx/g4z2+j9SjOzOF+MKyIiFVflwDezSOAF4AagAzDCzDqU6vZvoJ9zLhZ4GphS1XFFRKRyfLGHfzWwwzm3yzl3ApgNDCnewTmX5pz7zrv4GXCZD8YVEQlZy5YtIy0tzafb9EXgtwR2F1ve420rz73Ah+WtNLPRZpZuZulZWVk+KE9EpOYJ1sC3MtrKvDO6mV2LJ/AfLW9jzrkpzrkE51xCixYtfFCeiEjwmDlzJrGxscTFxXHnnXeyYMECevToQdeuXbnuuuv45ptvyMzMJDU1lWeffZYuXbqwYsUKn4zti8sj7wFaFVu+DNhXupOZxQKvADc45w75YFwRkRpl8+bNPPPMM3z66ac0b96cw4cPY2Z89tlnmBmvvPIK//M//8Pf/vY3kpOTiYqK4uGHH/bZ+L4I/DVAWzNrDewFbgNuL97BzH4AzAXudM5t88GYIiI1zpIlS7j11ltp3rw5ABdccAGff/45w4cPZ//+/Zw4cYLWrVtX2/hVntJxzhUAY4BFwJfAW865zWaWbGbJ3m5PAhcCk80sw8zSqzquiEhN45zDrOQs+IMPPsiYMWP4/PPPeemllzh27Fi1je+T8/Cdcwudc1c6565wzj3jbUt1zqV6H9/nnGvmnOvi/UrwxbgiIjVJUlISb731FocOeWa1Dx8+TE5ODi1bes5zmTFjRlHfRo0akZub69Px9UlbERE/6dixI48//jj9+vUjLi6O3/zmN6SkpDBs2DD69OlTNNUDcOONN/Luu+/69KCtOVfmCTVBISEhwaWna/ZHRMLDtlUHWDl/J3mHjxN1QV16DrmCK3tcUqltmNna8mZRdBNzEZEgsG3VAZbO2kLBiUIA8g4fZ+msLQCVDv3yaEpHRCQIrJy/syjsTys4UcjK+Tt9NoYCX0QkCOQdPl6p9vOhwBcRCQJRF9StVPv5UOCLiASBnkOuoFadkpFcq04EPYdc4bMxdNBWRCQInD4wW9WzdM5GgS8iEiSu7HGJTwO+NE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGxcGevNzCZ51280s3hfjCsiIhVX5cA3s0jgBeAGoAMwwsw6lOp2A9DW+zUaeLGq44qISOX4Yg//amCHc26Xc+4EMBsYUqrPEGCm8/gMaGpm0T4YW0REKsgXgd8S2F1seY+3rbJ9ADCz0WaWbmbpWVlZPihPRETAN4FvZbS58+jjaXRuinMuwTmX0KJFiyoXJyIiHr4I/D1Aq2LLlwH7zqOPiIhUI18E/hqgrZm1NrM6wG3Ae6X6vAeM9J6tkwjkOOf2+2BsERGpoCrf09Y5V2BmY4BFQCQwzTm32cySvetTgYXAT4AdwPfAqKqOKyIileOTm5g75xbiCfXibanFHjvgAV+MJSIi50eftBURCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMFGlwDezC8zs/5nZdu/3ZmX0aWVmS83sSzPbbGZjqzKmiIicn6ru4Y8DPnbOtQU+9i6XVgD81jnXHkgEHjCzDlUcV0REKqmqgT8EmOF9PAMYWrqDc26/c26d93Eu8CXQsorjiohIJVU18C92zu0HT7ADF52ts5nFAF2BVWfpM9rM0s0sPSsrq9IF9erVq9LPEREJB7XO1cHMFgOXlLHq8coMZGZRwBzg1865I+X1c85NAaYAJCQkuMqMAZCWllbZp4iIhIVzBr5z7rry1pnZN2YW7Zzbb2bRwMFy+tXGE/aznHNzz7vaCoiKiiIvL4/9+/czfPhwjhw5QkFBAS+++CJ9+vSpzqFFRIJaVad03gPu8j6+C5hfuoOZGTAV+NI59/cqjldhb7zxBgMGDCAjI4MNGzbQpUsXfw0tIhKUzrmHfw5/Ad4ys3uBr4FhAGZ2KfCKc+4nQG/gTuBzM8vwPu/3zrmFVRz7rLp3784999zDyZMnGTp0qAJfRMJelfbwnXOHnHNJzrm23u+Hve37vGGPc+4T55w552Kdc128X9Ua9gB9+/Zl+fLltGzZkjvvvJOZM2dW95AiIkEtZD9p+9VXX3HRRRfxi1/8gnvvvZd169YFuiQRkYCq6pRO0Fq2bBkTJkygdu3aREVFaQ9fRMJeyAV+Xto0eLYTd+Xs4a57L4OkxyD2Z4EuS0Qk4EIr8De+BQv+C07me5ZzdnuWQaEvImEvtObwP37qP2F/2sl8T7uISJgLrcDP2VO5dhGRMBJagd/kssq1i4iEkdAK/KQnoXb9km2163vaRUTCXGgFfuzP4MZJ0KQVYJ7vN07SAVsREULtLB3whLsCXkTkDKG1hy8iIuVS4IuIhAkFvohImFDgi4iECQW+iEiYMOcqfdtYvzGzLOArH22uOfCtj7blS6qrclRX5QRrXRC8tdX0ui53zrUoa0VQB74vmVm6cy4h0HWUproqR3VVTrDWBcFbWyjXpSkdEZEwocAXEQkT4RT4UwJdQDlUV+WorsoJ1rogeGsL2brCZg5fRCTchdMevohIWFPgi4iEiZAKfDMbaGZbzWyHmY0rY/0dZrbR+5VmZnFBVNsQb10ZZpZuZtcEQ13F+nU3s1Nmdmsw1GVm/c0sx/t+ZZiZX256UJH3y1tbhpltNrN/BUNdZvZIsfdqk/ff8oIgqKuJmS0wsw3e92tUdddUwbqamdm73t/J1WbWyU91TTOzg2a2qZz1ZmaTvHVvNLP4Sg3gnAuJLyAS2Am0AeoAG4AOpfr0App5H98ArAqi2qL4zzGVWGBLMNRVrN8SYCFwazDUBfQH3g/Cn7GmwBfAD7zLFwVDXaX63wgsCYa6gN8Df/U+bgEcBuoEQV0TgD94H18FfOynn7G+QDywqZz1PwE+BAxIrGyGhdIe/tXADufcLufcCWA2MKR4B+dcmnPuO+/iZ4C/7n1YkdrynPdfFGgI+ONo+jnr8noQmAMc9ENNlanL3ypS1+3AXOfc1wDOOX+8Z5V9v0YAbwZJXQ5oZGaGZ6fnMFAQBHV1AD4GcM5tAWLM7OJqrgvn3HI870F5hgAzncdnQFMzi67o9kMp8FsCu4st7/G2ledePP9T+kOFajOzm8xsC/ABcE8w1GVmLYGbgFQ/1FPhurx6eqcCPjSzjkFS15VAMzNbZmZrzWxkkNQFgJk1AAbi+Q88GOp6HmgP7AM+B8Y65wqDoK4NwM0AZnY1cDn+20E8m8rmXAmhFPhWRluZe8lmdi2ewH+0WisqNmQZbWfU5px71zl3FTAUeLq6i6Jidf0v8Khz7lT1l1OkInWtw3PNkDjgH8C86i6KitVVC+gGDAIGAOPN7MogqOu0G4FPnXNn24v0lYrUNQDIAC4FugDPm1nj6i2rQnX9Bc9/3Bl4/sJdT/X/5VERlfm3PkMo3eJwD9Cq2PJlePYaSjCzWOAV4Abn3KFgqu0059xyM7vCzJo756rzIk4VqSsBmO35i5vmwE/MrMA5Ny+QdTnnjhR7vNDMJgfJ+7UH+NY5dxQ4ambLgThgW4DrOu02/DOdAxWraxTwF+905g4z+zeeOfPVgazL+/M1CjwHSoF/e78CrVJZcgZ/HIjw08GOWsAuoDX/ORDTsVSfHwA7gF5BWNsP+c9B23hg7+nlQNZVqv90/HPQtiLv1yXF3q+rga+D4f3CMz3xsbdvA2AT0CnQdXn7NcEzP9ywuv8NK/F+vQikeB9f7P25bx4EdTXFe/AY+AWeefNqf8+848VQ/kHbQZQ8aLu6MtsOmT1851yBmY0BFuE5Cj/NObfZzJK961OBJ4ELgcnePdYC54er4lWwtluAkWZ2EsgHhjvvv3CA6/K7CtZ1K/ArMyvA837dFgzvl3PuSzP7J7ARKARecc6VeYqdP+vydr0J+Mh5/vqodhWs62lgupl9jifEHnXV+1daRetqD8w0s1N4zrq6tzprOs3M3sRzBlpzM9sD/AGoXayuhXjO1NkBfI/3r5AKb7+af0dERCRIhNJBWxEROQsFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhIn/DyCri3Zc6/JlAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "打印\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.autograd import variable\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import tqdm\n", + "\n", + "dtype = torch.FloatTensor\n", + "#我们使用的语料库 \n", + "sentences = ['i like dog','i like cat','i like animal','dog is animal','cat is animal',\n", + " 'dog like meat','cat like meat','cat like fish','dog like meat','i like apple',\n", + " 'i hate apple','i like movie','i like read','dog like bark','dog like cat']\n", + "\n", + "\n", + "\n", + "word_sequence = ' '.join(sentences).split() #将语料库的每一句话的每一个词转化为列表 \n", + "#print(word_sequence)\n", + "\n", + "word_list = list(set(word_sequence)) #构建我们的词表 \n", + "#print(word_list)\n", + "\n", + "#word_voc = list(set(word_sequence)) \n", + "\n", + "#接下来对此表中的每一个词编号 这就用到了我们之前提到的one-hot编码 \n", + "\n", + "#词典 词对应着编号\n", + "word_dict = {w:i for i,w in enumerate(word_list)}\n", + "#print(word_dict)\n", + "#编号对应着词\n", + "index_dict = {i:w for w,i in enumerate(word_list)}\n", + "#print(index_dict)\n", + "\n", + "\n", + "batch_size = 2\n", + "voc_size = len(word_list)\n", + "\n", + "skip_grams = []\n", + "for i in range(1,len(word_sequence)-1,3):\n", + " target = word_dict[word_sequence[i]] #当前词对应的id\n", + " context = [word_dict[word_sequence[i-1]],word_dict[word_sequence[i+1]]] #两个上下文词对应的id\n", + "\n", + " for w in context:\n", + " skip_grams.append([target,w])\n", + "\n", + "embedding_size = 10 \n", + "\n", + "\n", + "class Word2Vec(nn.Module):\n", + " def __init__(self):\n", + " super(Word2Vec,self).__init__()\n", + " self.W1 = nn.Parameter(torch.rand(len(word_dict),embedding_size)).type(dtype) \n", + " #将词的one-hot编码对应到词向量中\n", + " self.W2 = nn.Parameter(torch.rand(embedding_size,voc_size)).type(dtype)\n", + " #将词向量 转化为 输出 \n", + " def forward(self,x):\n", + " hidden_layer = torch.matmul(x,self.W1)\n", + " output_layer = torch.matmul(hidden_layer,self.W2)\n", + " return output_layer\n", + "\n", + "\n", + "model = Word2Vec()\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(),lr=1e-5)\n", + "\n", + "#print(len(skip_grams))\n", + "#训练函数\n", + "\n", + "def random_batch(data,size):\n", + " random_inputs = []\n", + " random_labels = []\n", + " random_index = np.random.choice(range(len(data)),size,replace=False)\n", + " \n", + " for i in random_index:\n", + " random_inputs.append(np.eye(voc_size)[data[i][0]]) #从一个单位矩阵生成one-hot表示\n", + " random_labels.append(data[i][1])\n", + " \n", + " return random_inputs,random_labels\n", + "\n", + "for epoch in tqdm(range(100000)):\n", + " input_batch,target_batch = random_batch(skip_grams,batch_size) # X -> y\n", + " input_batch = torch.Tensor(input_batch)\n", + " target_batch = torch.LongTensor(target_batch)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " output = model(input_batch)\n", + "\n", + " loss = criterion(output,target_batch)\n", + " if((epoch+1)%10000==0):\n", + " print(\"epoch:\",\"%04d\" %(epoch+1),'cost =' ,'{:.6f}'.format(loss))\n", + "\n", + " loss.backward() \n", + " optimizer.step()\n", + "\n", + "for i , label in enumerate(word_list):\n", + " W1,_ = model.parameters()\n", + " x,y = float(W1[i][0]),float(W1[i][1])\n", + " plt.scatter(x,y)\n", + " plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1edccf25", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9 (default, Aug 31 2020, 12:42:55) \n[GCC 7.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "7648c2b9d25760d0d65f53f9b9a34de48caa24d8265d64b0ff81e2f2641d528d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a399dea09f46145a5486fedbb41c200e0438ccd5 Mon Sep 17 00:00:00 2001 From: ZHOUhuichi <96902323+ZHOUhuichi@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:33:47 +0800 Subject: [PATCH 2/6] =?UTF-8?q?Create=20=E5=91=A8=E8=BE=89=E6=B1=A0NLP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../\345\221\250\350\276\211\346\261\240NLP" | 1 + 1 file changed, 1 insertion(+) create mode 100644 "docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" diff --git "a/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" "b/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ "b/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" @@ -0,0 +1 @@ + From f43674db2c361ef62a15d192419a40be9735bc50 Mon Sep 17 00:00:00 2001 From: ZHOUhuichi <96902323+ZHOUhuichi@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:35:26 +0800 Subject: [PATCH 3/6] Add files via upload --- .../zhcNLP/NLP\345\237\272\347\241\200.ipynb" | 687 ++++++++++++++++++ .../zhcNLP/NLP\345\237\272\347\241\200.md" | 434 +++++++++++ .../zhcNLP/output_16_20.png" | Bin 0 -> 10308 bytes 3 files changed, 1121 insertions(+) create mode 100644 "docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.ipynb" create mode 100644 "docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" create mode 100644 "docs/\347\254\254\345\215\201\347\253\240/zhcNLP/output_16_20.png" diff --git "a/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.ipynb" "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.ipynb" new file mode 100644 index 000000000..ba5bcc1f9 --- /dev/null +++ "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.ipynb" @@ -0,0 +1,687 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "119ec186", + "metadata": {}, + "source": [ + "# 词嵌入(概念部分)" + ] + }, + { + "cell_type": "markdown", + "id": "f8e5639e", + "metadata": {}, + "source": [ + "###   在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
\n", + " 计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
\n", + " 我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
\n", + " 例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
\n", + "\n", + "
一个小朋友刚学会了四个字和词-->[我] [特别] [喜欢] [学习]
\n", + " 我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
\n", + " 对于中文 我们先进行分词 我 特别 喜欢 学习
\n", + " 那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
\n", + " 现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

\n", + " 我们可以思考几个问题:
\n", + " 1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
\n", + " 2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
\n", + " 因此我们认为独热编码不是一个很好的词嵌入方法。
\n", + "\n", + " 我们再来介绍一下 **稠密表示**
\n", + "\n", + " 稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97]" + ] + }, + { + "cell_type": "markdown", + "id": "4db86da3", + "metadata": {}, + "source": [ + "# Bag of Words词袋表示" + ] + }, + { + "cell_type": "markdown", + "id": "44dc9252", + "metadata": {}, + "source": [ + "  词袋表示顾名思义,我们往一个袋子中装入我们的词汇,构成一个词袋,当我们想表达的时候,我们将其取出,构建词袋的方法可以有如下形式。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "823f8f2d", + "metadata": {}, + "outputs": [], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]#我们的语料库\n", + "word_list = ' '.join(corpus).split()\n", + "word_list = list(sorted(set(word_list)))\n", + "word_dict = {w: i for i, w in enumerate(word_list)}\n", + "number_dict = {i: w for i, w in enumerate(word_list)}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8eaeb37d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'do': 0,\n", + " 'drinking': 1,\n", + " 'hate': 2,\n", + " 'i': 3,\n", + " 'like': 4,\n", + " 'love': 5,\n", + " 'nlp': 6,\n", + " 'playing': 7,\n", + " 'reading': 8}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2bf380c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: 'do',\n", + " 1: 'drinking',\n", + " 2: 'hate',\n", + " 3: 'i',\n", + " 4: 'like',\n", + " 5: 'love',\n", + " 6: 'nlp',\n", + " 7: 'playing',\n", + " 8: 'reading'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "number_dict" + ] + }, + { + "cell_type": "markdown", + "id": "90e0ef43", + "metadata": {}, + "source": [ + " 根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9821ed2a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "voc_size = len(word_dict)\n", + "bow = []\n", + "for i,name in enumerate(word_dict):\n", + " bow.append(np.eye(voc_size)[word_dict[name]])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "03f1f12f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 1., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 1., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 1., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 1., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 1., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 1., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 1., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 0., 1.])]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bow" + ] + }, + { + "cell_type": "markdown", + "id": "086a5fd2", + "metadata": {}, + "source": [ + "# N-gram:基于统计的语言模型\n", + " N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。\n", + "\n", + " N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 \n", + "\n", + "**N-gram 模型的优点:**\n", + "\n", + "简单易用,N-gram 模型的概念非常简单,实现起来也很容易。 \n", + "能够捕捉到语言中的相关性,N-gram 模型通过考虑连续出现的 n 个词语来预测下一个词语,因此它能够捕捉到语言中词语之间的相关性。 \n", + "可以使用已有的语料库进行训练,N-gram 模型可以使用已有的大量语料库进行训练,例如 Google 的 N-gram 数据库,这样可以大大提高模型的准确性。 \n", + "\n", + "**N-gram 模型的缺点:**\n", + "\n", + "对于短文本数据集不适用,N-gram 模型需要大量的文本数据进行训练,因此对于短文本数据集可能无法达到较高的准确性。 \n", + "容易受到噪声和语言纠错的影响,N-gram 模型是基于语料库进行训练的,如果语料库中包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。 \n", + "无法捕捉到语言中的非线性关系,N-gram 模型假设语言中的关系是线性的,但事实上语言中可能存在复杂的非线性关系,N-gram 模型无法捕捉到这些关系。" + ] + }, + { + "cell_type": "markdown", + "id": "1f5ad65b", + "metadata": {}, + "source": [ + "# NNLM:前馈神经网络语言模型\n", + " 下面通过前馈神经网络模型来**展示滑动**窗口的使用" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7bddfa77", + "metadata": {}, + "outputs": [], + "source": [ + "#导入必要的库\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from tqdm import tqdm\n", + "from torch.autograd import Variable\n", + "dtype = torch.FloatTensor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "29f23588", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i',\n", + " 'like',\n", + " 'reading',\n", + " 'i',\n", + " 'love',\n", + " 'drinking',\n", + " 'i',\n", + " 'hate',\n", + " 'playing',\n", + " 'i',\n", + " 'do',\n", + " 'nlp']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split()\n", + "word_list" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "12b58886", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 1000 cost = 1.010682\n", + "epoch: 2000 cost = 0.695155\n", + "epoch: 3000 cost = 0.597085\n", + "epoch: 4000 cost = 0.531892\n", + "epoch: 5000 cost = 0.376044\n", + "epoch: 6000 cost = 0.118038\n", + "epoch: 7000 cost = 0.077081\n", + "epoch: 8000 cost = 0.053636\n", + "epoch: 9000 cost = 0.038089\n", + "epoch: 10000 cost = 0.027224\n", + "[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'do']] -> ['studying', 'datawhale', 'playing', 'nlp']\n" + ] + } + ], + "source": [ + "#构建我们需要的语料库\n", + "corpus = [\"i like studying\", \"i love datawhale\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split() #将语料库转化为一个个单词 ,如['i', 'like', 'reading', 'i', ...,'nlp']\n", + "word_list = list(sorted(set(word_list))) #用set去重后转化为链表\n", + "# print(word_list)\n", + "\n", + "word_dict = {w: i for i, w in enumerate(word_list)} #将词表转化为字典 这边是词对应到index\n", + "number_dict = {i: w for i, w in enumerate(word_list)}#这边是index对应到词\n", + "# print(word_dict)\n", + "# print(number_dict)\n", + "\n", + "n_class = len(word_dict) #计算出我们词表的大小,用于后面词向量的构建\n", + "\n", + "m = 2 #词嵌入维度\n", + "n_step = 2 #滑动窗口的大小\n", + "n_hidden = 2 #隐藏层的维度为2\n", + "\n", + "\n", + "def make_batch(sentence): #由于语料库较小,我们象征性将训练集按照批次处理 \n", + " input_batch = []\n", + " target_batch = []\n", + "\n", + " for sen in sentence:\n", + " word = sen.split()\n", + " input = [word_dict[n] for n in word[:-1]]\n", + " target = word_dict[word[-1]]\n", + "\n", + " input_batch.append(input)\n", + " target_batch.append(target)\n", + "\n", + " return input_batch, target_batch\n", + "\n", + "\n", + "class NNLM(nn.Module): #搭建一个NNLM语言模型\n", + " def __init__(self):\n", + " super(NNLM, self).__init__()\n", + " self.embed = nn.Embedding(n_class, m)\n", + " self.W = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n", + " self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n", + "\n", + " self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n", + " self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n", + "\n", + " def forward(self, x):\n", + " x = self.embed(x) # 4 x 2 x 2\n", + " x = x.view(-1, n_step * m)\n", + " tanh = torch.tanh(self.d + torch.mm(x, self.W)) # 4 x 2\n", + " output = self.b + torch.mm(tanh, self.U)\n", + " return output\n", + "\n", + "model = NNLM()\n", + "\n", + "criterion = nn.CrossEntropyLoss() #损失函数的设置\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001) #优化器的设置\n", + "\n", + "input_batch, target_batch = make_batch(corpus) #训练集和标签值\n", + "input_batch = Variable(torch.LongTensor(input_batch))\n", + "target_batch = Variable(torch.LongTensor(target_batch))\n", + "\n", + "for epoch in range(10000): #训练过程\n", + " optimizer.zero_grad()\n", + "\n", + " output = model(input_batch) # input: 4 x 2\n", + "\n", + " loss = criterion(output, target_batch)\n", + "\n", + " if (epoch + 1) % 1000 == 0:\n", + " print('epoch:', '%04d' % (epoch + 1), 'cost = {:.6f}'.format(loss.item()))\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "predict = model(input_batch).data.max(1, keepdim=True)[1]#模型预测过程\n", + "\n", + "print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in predict.squeeze()])" + ] + }, + { + "cell_type": "markdown", + "id": "93d8cd2f", + "metadata": {}, + "source": [ + "# Word2Vec模型:主要采用Skip-gram和Cbow两种模式\n", + " 前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。\n", + " skip-gram模型(跳字模型)是用中心词去预测周围词\n", + " cbow模型(连续词袋模型)是用周围词预测中心词" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "066f68a0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 10000 cost = 1.955088\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 20000 cost = 1.673096\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 30000 cost = 2.247422\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 40000 cost = 2.289902\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 50000 cost = 2.396217\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 60000 cost = 1.539688\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 70000 cost = 1.638879\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 80000 cost = 2.279797\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 90000 cost = 1.992100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 100000 cost = 1.307715\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD6CAYAAACiefy7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnsUlEQVR4nO3de3hU1b3/8fc34U64KagRqQGLyC2BECSAXGysYKmAF4poRVFLU8VDbbViFZtqPW0PtB6pYkRBQFGOCoIoFn8IFDQIBAgIyr1RrhLBhATDJWT9/pghTUICCZnMTGY+r+fJk9lrr9nrO0PyYWftPXubcw4REQl9EYEuQERE/EOBLyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+FKtzCzGzDZVov9QM+tQnTWJhCsL5vPwmzdv7mJiYgJdhlTB8ePH2bFjBx07dqxQ/8zMTJo0aUKzZs2quTKR0LR27dpvnXMtylpXy9/FVEZMTAzp6emBLkOqIDMzkxtuuIGuXbuSlpZGy5YtmT9/Pq+//jpTpkzhxIkT/PCHP+S1114jIyODn/70pxQWFlJQUMCcOXMAeOCBB8jKyqJBgwa8/PLLXHXVVQF+VSLBy8y+Km+dpnSk2m3fvp0HHniAzZs307RpU+bMmcPNN9/MmjVr2LBhA+3bt2fq1Kn06tWLwYMHM2HCBDIyMrjiiisYPXo0//jHP1i7di0TJ07k/vvvD/TLEamxgnoPX0JD69at6dKlCwDdunUjMzOTTZs28cQTT5CdnU1eXh4DBgw443l5eXmkpaUxbNiworbjx4/7q2yRkKPAl2pXt27doseRkZHk5+dz9913M2/ePOLi4pg+fTrLli0743mFhYU0bdqUjIwM/xUrEsI0pSMBkZubS3R0NCdPnmTWrFlF7Y0aNSI3NxeAxo0b07p1a95++20AnHNs2LAhIPWKhAIFvgTE008/TY8ePfjxj39c4iDsbbfdxoQJE+jatSs7d+5k1qxZTJ06lbi4ODp27Mj8+fMDWLVIzRbUp2UmJCQ4naUTvuat38uERVvZl53PpU3r88iAdgzt2jLQZYkENTNb65xLKGud5vAlKM1bv5fH5n5O/slTAOzNzuexuZ8DKPRFzpOmdCQoTVi0tSjsT8s/eYoJi7YGqCKRmk+BL0FpX3Z+pdpF5NwU+BKULm1av1LtInJuCnwJSo8MaEf92pEl2urXjuSRAe0CVJFIzaeDthKUTh+Y1Vk6Ir6jwJegNbRrSwW8iA/5ZErHzKaZ2cHyrntuHpPMbIeZbTSzeF+MKyIiFeerOfzpwMCzrL8BaOv9Gg286KNxRUSkgnwS+M655cDhs3QZAsx0Hp8BTc0s2hdji4hIxfjrLJ2WwO5iy3u8bWcws9Fmlm5m6VlZWX4pTkQkHPgr8K2MtjIv4uOcm+KcS3DOJbRoUeZdukRE5Dz4K/D3AK2KLV8G7PPT2CIigv8C/z1gpPdsnUQgxzm3309ji4gIPjoP38zeBPoDzc1sD/AHoDaAcy4VWAj8BNgBfA+M8sW4IiJScT4JfOfciHOsd8ADvhhLRETOj66lIyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+BKSMjMz6dSp03k//+677+add97xYUUigafAFymloKAg0CWIVAsFvoSsgoIC7rrrLmJjY7n11lv5/vvveeqpp+jevTudOnVi9OjReD4iAv379+f3v/89/fr147nnniuxnfHjx3P33XdTWFgYiJch4jMKfAlZW7duZfTo0WzcuJHGjRszefJkxowZw5o1a9i0aRP5+fm8//77Rf2zs7P517/+xW9/+9uitt/97nccPHiQV199lYgI/bpIzaafYAlZrVq1onfv3gD8/Oc/55NPPmHp0qX06NGDzp07s2TJEjZv3lzUf/jw4SWe//TTT5Odnc1LL72EWVkXfBWpWXRPWwlZpUPazLj//vtJT0+nVatWpKSkcOzYsaL1DRs2LNG/e/furF27lsOHD3PBBRf4pWaR6qQ9fAlZX3/9NStXrgTgzTff5JprrgGgefPm5OXlnfMsnIEDBzJu3DgGDRpEbm5utdcrUt20hy8hq3379syYMYNf/vKXtG3bll/96ld89913dO7cmZiYGLp3737ObQwbNozc3FwGDx7MwoULqV+/vh8qF6kedvoshWCUkJDg0tPTA12GhJkPdn3Ac+ue48DRA1zS8BLGxo9lUJtBgS5LpELMbK1zLqGsddrDFynmg10fkJKWwrFTnrn9/Uf3k5KWAqDQlxpPc/gixTy37rmisD/t2KljPLfuuXKeIVJzKPBFijlw9ECl2kVqEgW+SDGXNLykUu0iNYkCX6SYsfFjqRdZr0Rbvch6jI0fG6CKRHxHgS9nlZqaysyZM32yrZiYGL799lufbKu6DGoziJReKUQ3jMYwohtGk9IrRQdsJSToLB05q+Tk5ECX4HeD2gxSwEtI8skevpkNNLOtZrbDzMaVsb6JmS0wsw1mttnMRvliXDk/Q4cOpVu3bnTs2JEpU6YAEBUVxeOPP05cXByJiYl88803AKSkpDBx4kTAc0XJhx56iL59+9K+fXvWrFnDzTffTNu2bXniiSfOun0RCbwqB76ZRQIvADcAHYARZtahVLcHgC+cc3FAf+BvZlanqmPL+Zk2bRpr164lPT2dSZMmcejQIY4ePUpiYiIbNmygb9++vPzyy2U+t06dOixfvpzk5GSGDBnCCy+8wKZNm5g+fTqHDh0qd/siEni+2MO/GtjhnNvlnDsBzAaGlOrjgEbmuZpVFHAY0F0mAmTSpElFe/K7d+9m+/bt1KlTh5/+9KcAdOvWjczMzDKfO3jwYAA6d+5Mx44diY6Opm7durRp04bdu3eXu/2aqlevXoEuQcRnfDGH3xLYXWx5D9CjVJ/ngfeAfUAjYLhzrsy7SZjZaGA0wA9+8AMflCfFLVu2jMWLF7Ny5UoaNGhA//79OXbsGLVr1y66umRkZGS5d32qW7cuABEREUWPTy8XFBSUu/2aKi0tLdAliPiML/bwy7pQeOkL9AwAMoBLgS7A82bWuKyNOeemOOcSnHMJLVq08EF5UlxOTg7NmjWjQYMGbNmyhc8++6xGbd/foqKiAl1CtSp+jEZCny8Cfw/QqtjyZXj25IsbBcx1HjuAfwNX+WBsqaSBAwdSUFBAbGws48ePJzExsUZtX0TOX5WvlmlmtYBtQBKwF1gD3O6c21ysz4vAN865FDO7GFgHxDnnznpStq6WWXMdXX+QI4syOZV9nMimdWk8IIaGXS8KdFmVFhUVRV5eXqDL8KlnnnmGmTNn0qpVK1q0aEG3bt247rrrSE5O5vvvv+eKK65g2rRpNGvWjDVr1nDvvffSsGFDrrnmGj788EM2bdoU6JcgZ3G2q2VWeQ/fOVcAjAEWAV8CbznnNptZspmdPon7aaCXmX0OfAw8eq6wl5rr6PqDZM/dzqns4wCcyj5O9tztHF1/MMCVydq1a5k9ezbr169n7ty5rFmzBoCRI0fy17/+lY0bN9K5c2f++Mc/AjBq1ChSU1NZuXIlkZGRgSxdfMAn5+E75xY65650zl3hnHvG25bqnEv1Pt7nnLveOdfZOdfJOfe6L8YV/5g+fTpjxoypcP8jizJxJ0sek3cnCzmyKNPHlUllrVixgptuuokGDRrQuHFjBg8ezNGjR8nOzqZfv34A3HXXXSxfvpzs7Gxyc3OLzlS6/fbbA1m6+IAurSA+d3rPvqLt4l8VvSF7MN8cSc6PAj/Elfep2t/+9rfEx8eTlJREVlYW4Pkk7a9//Wt69epFp06dWL169Rnby8rK4pZbbqF79+50796dTz/99Iw+kU3rntF2tvagsfEteLYTpDT1fN/4VsjN3/ft25d3332X/Px8cnNzWbBgAQ0bNqRZs2asWLECgNdee41+/frRrFkzGjVqVHSm1ezZswNZuviAAj/Elfep2vj4eNatW0e/fv2K5msBjh49SlpaGpMnT+aee+45Y3tjx47loYceYs2aNcyZM4f77rvvjD6NB8RgtUv+aFntCBoPiPH56/OZjW/Bgv+CnN2A83xf8F+e9hASHx/P8OHD6dKlC7fccgt9+vQBYMaMGTzyyCPExsaSkZHBk08+CcDUqVMZPXo0PXv2xDlHkyZNAlm+VJEunhbiJk2axLvvvgtQ9KnXiIgIhg8fDsDPf/5zbr755qL+I0aMADx7gkeOHCE7O7vE9hYvXswXX3xRtHzkyBFyc3Np1KhRUdvps3Fq1Fk6Hz8FJ/NLtp3M97TH/iwwNVWTxx9/nMcff/yM9tKfmdi4cSNLly7l5ptvpkmTJmRmZpKQUObJH1JDKPBDWEU/9Vp8Trf0/G7p5cLCQlauXEn9+vXPOnbDrhcFd8CXlrOncu0hbuPGjSxYsID169fzySefUFhYSLNmzXjppZcCXZpUgaZ0Qlh5n3otLCzknXfeAeCNN97gmmuuKXrO//3f/wHwySef0KRJkzP+hL/++ut5/vnni5YzMjKq+VX4SZPLKtce4j7++GNOnjxJp06dSE5O5v7772fEiBGsX78+0KVJFSjwQ1h5n3pt2LAhmzdvplu3bixZsqRovhagWbNm9OrVi+TkZKZOnXrGNidNmkR6ejqxsbF06NCB1NRUv72eapX0JNQu9VdL7fqe9jATFRVFTk4Oubm5vPWW5xhGRkYGCxcuJCcnJ8DVSVVU+ZO21UmftK0e5X16tH///kycOLHcedovVyxlxeyZ5B76lkYXNqfPbSNp3+fa6i7Xfza+5Zmzz9nj2bNPejLk5u8rIioqiqeffrpEuGdkZLBv3z5GjBjBQw89FMDq5Fyq9ZO2Eh6+XLGUj6Y8T+63WeAcud9m8dGU5/lyxdJAl+Y7sT+DhzZBSrbnexiG/WlJSUnk5eUxefLkoraIiAiSkpL44IMP6NmzJ99++y0fffQRPXv2JD4+nmHDhoXcaayhRoEfhsr7pVy2bFm5e/crZs+k4ETJD04VnDjOitm+ud+tBJfY2FiSkpKKLqdQv359YmJi2LlzJ3/5y19YuHAhAH/6059YvHgx69atIyEhgb///e+BLFvOQWfpSIXkHir70kfltUvN1759ey688EJSUlKYPn06EyZMIDMzk48++ojGjRvz/vvv88UXX9C7d28ATpw4Qc+ePQNctZyNAl8qpNGFzT3TOWW0S3ho06YNu3btYtu2bSQkJOCc48c//jFvvvlmoEuTCtKUjlRIn9tGUqtOyUsj1KpTlz63jQxQReJvl19+OXPnzmXkyJFs3ryZxMREPv30U3bs2AHA999/z7Zt2wJcpZyNAl8qpH2fa7l+9BgaNW8BZjRq3oLrR48JrbN05JzatWvHrFmzGDZsGEeOHGH69OmMGDGC2NhYEhMT2bJlS6BLlLPQaZkiUmn7D8xn186JHDu+n3p1o2lzxcNEXzIk0GUJZz8tU3P4IlIp+w/MZ8uWxyks9Fx76NjxfWzZ4rk2j0I/uGlKR0QqZdfOiUVhf1phYT67dupm6MFOgS8ilXLs+P5KtUvwUOCLSKXUqxtdqfbzkZmZyRtvvOGz7YmHAl9EKqXNFQ8TEVHyQnMREfVpc8XDPhtDgV89FPgiUinRlwzhqqueoV7dSwHju8MX8Iv7shn/xAI6derEHXfcweLFi+nduzdt27Zl9erVHD16lHvuuYfu3bvTtWtX5s+fD3iCvU+fPsTHxxMfH09aWhoA48aNY8WKFXTp0oVnn302gK82xDjnqvwFDAS2AjuAceX06Q9kAJuBf1Vku926dXMiEtz+/e9/u8jISLdx40Z36tQpFx8f70aNGuUKCwvdvHnz3JAhQ9xjjz3mXnvtNeecc999951r27aty8vLc0ePHnX5+fnOOee2bdvmTv/OL1261A0aNChgr6kmA9JdOZla5dMyzSwSeAH4MbAHWGNm7znnvijWpykwGRjonPvazGrQrZBE5Fxat25N586dAejYsSNJSUmYGZ07dyYzM5M9e/bw3nvvMXGi50yeY8eO8fXXX3PppZcyZswYMjIyiIyM1Cd1q5kvzsO/GtjhnNsFYGazgSHAF8X63A7Mdc59DeCcO+iDcUUkSNSt+5/LbkRERBQtR0REUFBQQGRkJHPmzKFdu3YlnpeSksLFF1/Mhg0bKCwspF69en6tO9z4Yg6/JbC72PIeb1txVwLNzGyZma01s3IvwGJmo80s3czSs7LOvFiXiNQ8AwYM4B//+Mfp6d2iWyXm5OQQHR1NREQEr732GqdOnQKgUaNG5ObmBqxeX5g+fTpjxowJdBkl+CLwrYy20tdrqAV0AwYBA4DxZnZlWRtzzk1xziU45xJatGjhg/JEJNDGjx/PyZMniY2NpVOnTowfPx6A+++/nxkzZpCYmMi2bdto2LAh4Lkef61atYiLiwvYQVvnHIWFhQEZu7pU+Vo6ZtYTSHHODfAuPwbgnPtzsT7jgHrOuRTv8lTgn865t8+27bKupTNp0iRefPFFDhw4wKOPPsq4cePKfO706dNJT08vccNtEQle89bvZcKirezLzufSpvV5ZEA7hnYtPVlQvTIzM7nhhhu49tprWblyJUOHDuX999/n+PHj3HTTTfzxj38EYOjQoezevZtjx44xduxYRo8eDcCrr77Kn//8Z6Kjo7nyyiupW7eu3zOouq+lswZoa2atgb3AbXjm7IubDzxvZrWAOkAP4Lz+2548eTIffvghrVu3rkLJIhJM5q3fy2NzPyf/pGdKZ292Po/N/RzA76G/detWXn31VYYOHco777zD6tWrcc4xePBgli9fTt++fZk2bRoXXHAB+fn5dO/enVtuuYUTJ07whz/8gbVr19KkSROuvfZaunbt6tfaz6XKUzrOuQJgDLAI+BJ4yzm32cySzSzZ2+dL4J/ARmA18IpzblNlx0pOTmbXrl0MHjyYZ599tmh+7O2336ZTp07ExcXRt2/fov779u1j4MCBtG3blt/97ndVfakiUk0mLNpaFPan5Z88xYRFW/1ey+WXX05iYiIfffQRH330EV27diU+Pp4tW7awfft2wDPTEBcXR2JiIrt372b79u2sWrWK/v3706JFC+rUqcPw4cP9Xvu5+ORqmc65hcDCUm2ppZYnABOqMk5qair//Oc/Wbp0Ke+//35R+1NPPcWiRYto2bIl2dnZRe0ZGRmsX7+eunXr0q5dOx588EFatWpVlRJEpBrsy86vVHt1On0cwTnHY489xi9/+csS65ctW8bixYtZuXIlDRo0oH///hw7dgwAs7IOaQaPkPikbe/evbn77rt5+eWXi47yAyQlJdGkSRPq1atHhw4d+OqrrwJYpYiU59Km9SvV7g8DBgxg2rRp5OXlAbB3714OHjxITk4OzZo1o0GDBmzZsoXPPvsMgB49erBs2TIOHTrEyZMnefvtsx6iDIiQuB5+amoqq1at4oMPPqBLly5kZGQAJc8NjoyMpKCgIEAVisjZPDKgXYk5fID6tSN5ZEC7szyrel1//fV8+eWXRTdmj4qK4vXXX2fgwIGkpqYSGxtLu3btSExMBCA6OpqUlBR69uxJdHQ08fHxJXZAg0FIBP7OnTvp0aMHPXr0YMGCBezevfvcTxKRoHH6wGygz9KJiYlh06b/HF4cO3YsY8eOPaPfhx9+eEZbzoIFXPPa68yPrEWtyFpcdN11NLnxxmqtt7JCIvAfeeQRtm/fjnOOpKQk4uLiivbyRaRmGNq1pd8D3ldyFixg//gncd65/IJ9+9g//kmAoAp93dNWRKSKtv8oiYJ9+85or3XppbRd8rFfawm7e9rOOXCYP+/az97jJ2lZtzaPtYnmlksuCHRZIhKiCvaXfbev8toDJSTO0iluzoHDPLx1N3uOn8QBe46f5OGtu5lz4HCgSxOREFUruuy7fZXXHighF/h/3rWf/MKS01T5hY4/7wqu/2lFJHRc9NCvsVJX+rR69bjooV8HpqByhNyUzt7jJyvVLiJSVacPzB589n8p2L+fWtHRXPTQr4PqgC2EYOC3rFubPWWEe8u6tQNQjYiEiyY33hh0AV9ayE3pPNYmmvoRJT/eXD/CeKxNcM2liYj4W8jt4Z8+G0dn6YiIlBRygQ+e0FfAi4iUFHJTOiIiUjYFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGzcWfp1N7NTZnarL8YVEZGKq3Lgm1kk8AJwA9ABGGFmHcrp91dgUVXHFBGRyvPFHv7VwA7n3C7n3AlgNjCkjH4PAnOAgz4YU0REKskXgd8S2F1seY+3rYiZtQRuAlLPtTEzG21m6WaWnpWV5YPyREQEfBP4VkZb6Tuj/y/wqHPu1Lk25pyb4pxLcM4ltGjRwgfliYgI+OZqmXuAVsWWLwNK3749AZhtZgDNgZ+YWYFzbp4PxpcaIDU1lQYNGjBy5MhAlyIStnwR+GuAtmbWGtgL3AbcXryDc6716cdmNh14X2EfXpKTkwNdgkjYq/KUjnOuABiD5+ybL4G3nHObzSzZzPRbXgNlZmZy1VVXcd9999GpUyfuuOMOFi9eTO/evWnbti2rV6/m8OHDDB06lNjYWBITE9m4cSOFhYXExMSQnZ1dtK0f/vCHfPPNN6SkpDBx4kQAdu7cycCBA+nWrRt9+vRhy5YtAXqlIuHFJzdAcc4tBBaWaivzAK1z7m5fjCnVa8eOHbz99ttMmTKF7t2788Ybb/DJJ5/w3nvv8d///d+0atWKrl27Mm/ePJYsWcLIkSPJyMhgyJAhvPvuu4waNYpVq1YRExPDxRdfXGLbo0ePJjU1lbZt27Jq1Sruv/9+lixZEqBXKhI+QvKOV1J1rVu3pnPnzgB07NiRpKQkzIzOnTuTmZnJV199xZw5cwD40Y9+xKFDh8jJyWH48OE89dRTjBo1itmzZzN8+PAS283LyyMtLY1hw4YVtR0/ftx/L0wkjCnwpUx169YtehwREVG0HBERQUFBAbVqnfmjY2b07NmTHTt2kJWVxbx583jiiSdK9CksLKRp06ZkZGRUa/0iciZdS0fOS9++fZk1axYAy5Yto3nz5jRu3Bgz46abbuI3v/kN7du358ILLyzxvMaNG9O6dWvefvttAJxzbNiwwe/1i4QjBb6cl5SUFNLT04mNjWXcuHHMmDGjaN3w4cN5/fXXz5jOOW3WrFlMnTqVuLg4OnbsyPz58/1VtkhYM+dKf0YqeCQkJLj09PRAlyE+9MGuD3hu3XMcOHqASxpewtj4sQxqMyjQZYmEDDNb65xLKGud5vDFbz7Y9QEpaSkcO3UMgP1H95OSlgKg0BfxA03piN88t+65orA/7dipYzy37rkAVSQSXhT44jcHjh6oVLuI+JYCX/zmkoaXVKpdRHxLgS9+MzZ+LPUi65VoqxdZj7HxYwNUkUh40UFb8ZvTB2Z1lo5IYCjwxa8GtRmkgBcJEE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJnwS+mQ00s61mtsPMxpWx/g4z2+j9SjOzOF+MKyIiFVflwDezSOAF4AagAzDCzDqU6vZvoJ9zLhZ4GphS1XFFRKRyfLGHfzWwwzm3yzl3ApgNDCnewTmX5pz7zrv4GXCZD8YVEQlZy5YtIy0tzafb9EXgtwR2F1ve420rz73Ah+WtNLPRZpZuZulZWVk+KE9EpOYJ1sC3MtrKvDO6mV2LJ/AfLW9jzrkpzrkE51xCixYtfFCeiEjwmDlzJrGxscTFxXHnnXeyYMECevToQdeuXbnuuuv45ptvyMzMJDU1lWeffZYuXbqwYsUKn4zti8sj7wFaFVu+DNhXupOZxQKvADc45w75YFwRkRpl8+bNPPPMM3z66ac0b96cw4cPY2Z89tlnmBmvvPIK//M//8Pf/vY3kpOTiYqK4uGHH/bZ+L4I/DVAWzNrDewFbgNuL97BzH4AzAXudM5t88GYIiI1zpIlS7j11ltp3rw5ABdccAGff/45w4cPZ//+/Zw4cYLWrVtX2/hVntJxzhUAY4BFwJfAW865zWaWbGbJ3m5PAhcCk80sw8zSqzquiEhN45zDrOQs+IMPPsiYMWP4/PPPeemllzh27Fi1je+T8/Cdcwudc1c6565wzj3jbUt1zqV6H9/nnGvmnOvi/UrwxbgiIjVJUlISb731FocOeWa1Dx8+TE5ODi1bes5zmTFjRlHfRo0akZub69Px9UlbERE/6dixI48//jj9+vUjLi6O3/zmN6SkpDBs2DD69OlTNNUDcOONN/Luu+/69KCtOVfmCTVBISEhwaWna/ZHRMLDtlUHWDl/J3mHjxN1QV16DrmCK3tcUqltmNna8mZRdBNzEZEgsG3VAZbO2kLBiUIA8g4fZ+msLQCVDv3yaEpHRCQIrJy/syjsTys4UcjK+Tt9NoYCX0QkCOQdPl6p9vOhwBcRCQJRF9StVPv5UOCLiASBnkOuoFadkpFcq04EPYdc4bMxdNBWRCQInD4wW9WzdM5GgS8iEiSu7HGJTwO+NE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGxcGevNzCZ51280s3hfjCsiIhVX5cA3s0jgBeAGoAMwwsw6lOp2A9DW+zUaeLGq44qISOX4Yg//amCHc26Xc+4EMBsYUqrPEGCm8/gMaGpm0T4YW0REKsgXgd8S2F1seY+3rbJ9ADCz0WaWbmbpWVlZPihPRETAN4FvZbS58+jjaXRuinMuwTmX0KJFiyoXJyIiHr4I/D1Aq2LLlwH7zqOPiIhUI18E/hqgrZm1NrM6wG3Ae6X6vAeM9J6tkwjkOOf2+2BsERGpoCrf09Y5V2BmY4BFQCQwzTm32cySvetTgYXAT4AdwPfAqKqOKyIileOTm5g75xbiCfXibanFHjvgAV+MJSIi50eftBURCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMFGlwDezC8zs/5nZdu/3ZmX0aWVmS83sSzPbbGZjqzKmiIicn6ru4Y8DPnbOtQU+9i6XVgD81jnXHkgEHjCzDlUcV0REKqmqgT8EmOF9PAMYWrqDc26/c26d93Eu8CXQsorjiohIJVU18C92zu0HT7ADF52ts5nFAF2BVWfpM9rM0s0sPSsrq9IF9erVq9LPEREJB7XO1cHMFgOXlLHq8coMZGZRwBzg1865I+X1c85NAaYAJCQkuMqMAZCWllbZp4iIhIVzBr5z7rry1pnZN2YW7Zzbb2bRwMFy+tXGE/aznHNzz7vaCoiKiiIvL4/9+/czfPhwjhw5QkFBAS+++CJ9+vSpzqFFRIJaVad03gPu8j6+C5hfuoOZGTAV+NI59/cqjldhb7zxBgMGDCAjI4MNGzbQpUsXfw0tIhKUzrmHfw5/Ad4ys3uBr4FhAGZ2KfCKc+4nQG/gTuBzM8vwPu/3zrmFVRz7rLp3784999zDyZMnGTp0qAJfRMJelfbwnXOHnHNJzrm23u+Hve37vGGPc+4T55w552Kdc128X9Ua9gB9+/Zl+fLltGzZkjvvvJOZM2dW95AiIkEtZD9p+9VXX3HRRRfxi1/8gnvvvZd169YFuiQRkYCq6pRO0Fq2bBkTJkygdu3aREVFaQ9fRMJeyAV+Xto0eLYTd+Xs4a57L4OkxyD2Z4EuS0Qk4EIr8De+BQv+C07me5ZzdnuWQaEvImEvtObwP37qP2F/2sl8T7uISJgLrcDP2VO5dhGRMBJagd/kssq1i4iEkdAK/KQnoXb9km2163vaRUTCXGgFfuzP4MZJ0KQVYJ7vN07SAVsREULtLB3whLsCXkTkDKG1hy8iIuVS4IuIhAkFvohImFDgi4iECQW+iEiYMOcqfdtYvzGzLOArH22uOfCtj7blS6qrclRX5QRrXRC8tdX0ui53zrUoa0VQB74vmVm6cy4h0HWUproqR3VVTrDWBcFbWyjXpSkdEZEwocAXEQkT4RT4UwJdQDlUV+WorsoJ1rogeGsL2brCZg5fRCTchdMevohIWFPgi4iEiZAKfDMbaGZbzWyHmY0rY/0dZrbR+5VmZnFBVNsQb10ZZpZuZtcEQ13F+nU3s1Nmdmsw1GVm/c0sx/t+ZZiZX256UJH3y1tbhpltNrN/BUNdZvZIsfdqk/ff8oIgqKuJmS0wsw3e92tUdddUwbqamdm73t/J1WbWyU91TTOzg2a2qZz1ZmaTvHVvNLP4Sg3gnAuJLyAS2Am0AeoAG4AOpfr0App5H98ArAqi2qL4zzGVWGBLMNRVrN8SYCFwazDUBfQH3g/Cn7GmwBfAD7zLFwVDXaX63wgsCYa6gN8Df/U+bgEcBuoEQV0TgD94H18FfOynn7G+QDywqZz1PwE+BAxIrGyGhdIe/tXADufcLufcCWA2MKR4B+dcmnPuO+/iZ4C/7n1YkdrynPdfFGgI+ONo+jnr8noQmAMc9ENNlanL3ypS1+3AXOfc1wDOOX+8Z5V9v0YAbwZJXQ5oZGaGZ6fnMFAQBHV1AD4GcM5tAWLM7OJqrgvn3HI870F5hgAzncdnQFMzi67o9kMp8FsCu4st7/G2ledePP9T+kOFajOzm8xsC/ABcE8w1GVmLYGbgFQ/1FPhurx6eqcCPjSzjkFS15VAMzNbZmZrzWxkkNQFgJk1AAbi+Q88GOp6HmgP7AM+B8Y65wqDoK4NwM0AZnY1cDn+20E8m8rmXAmhFPhWRluZe8lmdi2ewH+0WisqNmQZbWfU5px71zl3FTAUeLq6i6Jidf0v8Khz7lT1l1OkInWtw3PNkDjgH8C86i6KitVVC+gGDAIGAOPN7MogqOu0G4FPnXNn24v0lYrUNQDIAC4FugDPm1nj6i2rQnX9Bc9/3Bl4/sJdT/X/5VERlfm3PkMo3eJwD9Cq2PJlePYaSjCzWOAV4Abn3KFgqu0059xyM7vCzJo756rzIk4VqSsBmO35i5vmwE/MrMA5Ny+QdTnnjhR7vNDMJgfJ+7UH+NY5dxQ4ambLgThgW4DrOu02/DOdAxWraxTwF+905g4z+zeeOfPVgazL+/M1CjwHSoF/e78CrVJZcgZ/HIjw08GOWsAuoDX/ORDTsVSfHwA7gF5BWNsP+c9B23hg7+nlQNZVqv90/HPQtiLv1yXF3q+rga+D4f3CMz3xsbdvA2AT0CnQdXn7NcEzP9ywuv8NK/F+vQikeB9f7P25bx4EdTXFe/AY+AWeefNqf8+848VQ/kHbQZQ8aLu6MtsOmT1851yBmY0BFuE5Cj/NObfZzJK961OBJ4ELgcnePdYC54er4lWwtluAkWZ2EsgHhjvvv3CA6/K7CtZ1K/ArMyvA837dFgzvl3PuSzP7J7ARKARecc6VeYqdP+vydr0J+Mh5/vqodhWs62lgupl9jifEHnXV+1daRetqD8w0s1N4zrq6tzprOs3M3sRzBlpzM9sD/AGoXayuhXjO1NkBfI/3r5AKb7+af0dERCRIhNJBWxEROQsFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhIn/DyCri3Zc6/JlAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "打印\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.autograd import variable\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import tqdm\n", + "\n", + "dtype = torch.FloatTensor\n", + "#我们使用的语料库 \n", + "sentences = ['i like dog','i like cat','i like animal','dog is animal','cat is animal',\n", + " 'dog like meat','cat like meat','cat like fish','dog like meat','i like apple',\n", + " 'i hate apple','i like movie','i like read','dog like bark','dog like cat']\n", + "\n", + "\n", + "\n", + "word_sequence = ' '.join(sentences).split() #将语料库的每一句话的每一个词转化为列表 \n", + "#print(word_sequence)\n", + "\n", + "word_list = list(set(word_sequence)) #构建我们的词表 \n", + "#print(word_list)\n", + "\n", + "#word_voc = list(set(word_sequence)) \n", + "\n", + "#接下来对此表中的每一个词编号 这就用到了我们之前提到的one-hot编码 \n", + "\n", + "#词典 词对应着编号\n", + "word_dict = {w:i for i,w in enumerate(word_list)}\n", + "#print(word_dict)\n", + "#编号对应着词\n", + "index_dict = {i:w for w,i in enumerate(word_list)}\n", + "#print(index_dict)\n", + "\n", + "\n", + "batch_size = 2\n", + "voc_size = len(word_list)\n", + "\n", + "skip_grams = []\n", + "for i in range(1,len(word_sequence)-1,3):\n", + " target = word_dict[word_sequence[i]] #当前词对应的id\n", + " context = [word_dict[word_sequence[i-1]],word_dict[word_sequence[i+1]]] #两个上下文词对应的id\n", + "\n", + " for w in context:\n", + " skip_grams.append([target,w])\n", + "\n", + "embedding_size = 10 \n", + "\n", + "\n", + "class Word2Vec(nn.Module):\n", + " def __init__(self):\n", + " super(Word2Vec,self).__init__()\n", + " self.W1 = nn.Parameter(torch.rand(len(word_dict),embedding_size)).type(dtype) \n", + " #将词的one-hot编码对应到词向量中\n", + " self.W2 = nn.Parameter(torch.rand(embedding_size,voc_size)).type(dtype)\n", + " #将词向量 转化为 输出 \n", + " def forward(self,x):\n", + " hidden_layer = torch.matmul(x,self.W1)\n", + " output_layer = torch.matmul(hidden_layer,self.W2)\n", + " return output_layer\n", + "\n", + "\n", + "model = Word2Vec()\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(),lr=1e-5)\n", + "\n", + "#print(len(skip_grams))\n", + "#训练函数\n", + "\n", + "def random_batch(data,size):\n", + " random_inputs = []\n", + " random_labels = []\n", + " random_index = np.random.choice(range(len(data)),size,replace=False)\n", + " \n", + " for i in random_index:\n", + " random_inputs.append(np.eye(voc_size)[data[i][0]]) #从一个单位矩阵生成one-hot表示\n", + " random_labels.append(data[i][1])\n", + " \n", + " return random_inputs,random_labels\n", + "\n", + "for epoch in tqdm(range(100000)):\n", + " input_batch,target_batch = random_batch(skip_grams,batch_size) # X -> y\n", + " input_batch = torch.Tensor(input_batch)\n", + " target_batch = torch.LongTensor(target_batch)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " output = model(input_batch)\n", + "\n", + " loss = criterion(output,target_batch)\n", + " if((epoch+1)%10000==0):\n", + " print(\"epoch:\",\"%04d\" %(epoch+1),'cost =' ,'{:.6f}'.format(loss))\n", + "\n", + " loss.backward() \n", + " optimizer.step()\n", + "\n", + "for i , label in enumerate(word_list):\n", + " W1,_ = model.parameters()\n", + " x,y = float(W1[i][0]),float(W1[i][1])\n", + " plt.scatter(x,y)\n", + " plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1edccf25", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9 (default, Aug 31 2020, 12:42:55) \n[GCC 7.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "7648c2b9d25760d0d65f53f9b9a34de48caa24d8265d64b0ff81e2f2641d528d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git "a/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" new file mode 100644 index 000000000..51843e99d --- /dev/null +++ "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" @@ -0,0 +1,434 @@ +# 词嵌入(概念部分) + +###   在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
+ 计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
+ 我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
+ 例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
+ +
一个小朋友刚学会了四个字和词-->[我] [特别] [喜欢] [学习]
+ 我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
+ 对于中文 我们先进行分词 我 特别 喜欢 学习
+ 那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
+ 现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

+ 我们可以思考几个问题:
+ 1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
+ 2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
+ 因此我们认为独热编码不是一个很好的词嵌入方法。
+ + 我们再来介绍一下 **稠密表示**
+ + 稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97] + +# Bag of Words词袋表示 + +  词袋表示顾名思义,我们往一个袋子中装入我们的词汇,构成一个词袋,当我们想表达的时候,我们将其取出,构建词袋的方法可以有如下形式。 + + +```python +corpus = ["i like reading", "i love drinking", "i hate playing", "i do nlp"]#我们的语料库 +word_list = ' '.join(corpus).split() +word_list = list(sorted(set(word_list))) +word_dict = {w: i for i, w in enumerate(word_list)} +number_dict = {i: w for i, w in enumerate(word_list)} +``` + + +```python +word_dict +``` + + + + + {'do': 0, + 'drinking': 1, + 'hate': 2, + 'i': 3, + 'like': 4, + 'love': 5, + 'nlp': 6, + 'playing': 7, + 'reading': 8} + + + + +```python +number_dict +``` + + + + + {0: 'do', + 1: 'drinking', + 2: 'hate', + 3: 'i', + 4: 'like', + 5: 'love', + 6: 'nlp', + 7: 'playing', + 8: 'reading'} + + + + 根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用) + + +```python +import numpy as np +voc_size = len(word_dict) +bow = [] +for i,name in enumerate(word_dict): + bow.append(np.eye(voc_size)[word_dict[name]]) +``` + + +```python +bow +``` + + + + + [array([1., 0., 0., 0., 0., 0., 0., 0., 0.]), + array([0., 1., 0., 0., 0., 0., 0., 0., 0.]), + array([0., 0., 1., 0., 0., 0., 0., 0., 0.]), + array([0., 0., 0., 1., 0., 0., 0., 0., 0.]), + array([0., 0., 0., 0., 1., 0., 0., 0., 0.]), + array([0., 0., 0., 0., 0., 1., 0., 0., 0.]), + array([0., 0., 0., 0., 0., 0., 1., 0., 0.]), + array([0., 0., 0., 0., 0., 0., 0., 1., 0.]), + array([0., 0., 0., 0., 0., 0., 0., 0., 1.])] + + + +# N-gram:基于统计的语言模型 + N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。 + + N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 + +**N-gram 模型的优点:** + +简单易用,N-gram 模型的概念非常简单,实现起来也很容易。 +能够捕捉到语言中的相关性,N-gram 模型通过考虑连续出现的 n 个词语来预测下一个词语,因此它能够捕捉到语言中词语之间的相关性。 +可以使用已有的语料库进行训练,N-gram 模型可以使用已有的大量语料库进行训练,例如 Google 的 N-gram 数据库,这样可以大大提高模型的准确性。 + +**N-gram 模型的缺点:** + +对于短文本数据集不适用,N-gram 模型需要大量的文本数据进行训练,因此对于短文本数据集可能无法达到较高的准确性。 +容易受到噪声和语言纠错的影响,N-gram 模型是基于语料库进行训练的,如果语料库中包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。 +无法捕捉到语言中的非线性关系,N-gram 模型假设语言中的关系是线性的,但事实上语言中可能存在复杂的非线性关系,N-gram 模型无法捕捉到这些关系。 + +# NNLM:前馈神经网络语言模型 + 下面通过前馈神经网络模型来**展示滑动**窗口的使用 + + +```python +#导入必要的库 +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.autograd import Variable +dtype = torch.FloatTensor +``` + + +```python +corpus = ["i like reading", "i love drinking", "i hate playing", "i do nlp"] + +word_list = ' '.join(corpus).split() +word_list +``` + + + + + ['i', + 'like', + 'reading', + 'i', + 'love', + 'drinking', + 'i', + 'hate', + 'playing', + 'i', + 'do', + 'nlp'] + + + + +```python +#构建我们需要的语料库 +corpus = ["i like studying", "i love datawhale", "i hate playing", "i do nlp"] + +word_list = ' '.join(corpus).split() #将语料库转化为一个个单词 ,如['i', 'like', 'reading', 'i', ...,'nlp'] +word_list = list(sorted(set(word_list))) #用set去重后转化为链表 +# print(word_list) + +word_dict = {w: i for i, w in enumerate(word_list)} #将词表转化为字典 这边是词对应到index +number_dict = {i: w for i, w in enumerate(word_list)}#这边是index对应到词 +# print(word_dict) +# print(number_dict) + +n_class = len(word_dict) #计算出我们词表的大小,用于后面词向量的构建 + +m = 2 #词嵌入维度 +n_step = 2 #滑动窗口的大小 +n_hidden = 2 #隐藏层的维度为2 + + +def make_batch(sentence): #由于语料库较小,我们象征性将训练集按照批次处理 + input_batch = [] + target_batch = [] + + for sen in sentence: + word = sen.split() + input = [word_dict[n] for n in word[:-1]] + target = word_dict[word[-1]] + + input_batch.append(input) + target_batch.append(target) + + return input_batch, target_batch + + +class NNLM(nn.Module): #搭建一个NNLM语言模型 + def __init__(self): + super(NNLM, self).__init__() + self.embed = nn.Embedding(n_class, m) + self.W = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype)) + self.d = nn.Parameter(torch.randn(n_hidden).type(dtype)) + + self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype)) + self.b = nn.Parameter(torch.randn(n_class).type(dtype)) + + def forward(self, x): + x = self.embed(x) # 4 x 2 x 2 + x = x.view(-1, n_step * m) + tanh = torch.tanh(self.d + torch.mm(x, self.W)) # 4 x 2 + output = self.b + torch.mm(tanh, self.U) + return output + +model = NNLM() + +criterion = nn.CrossEntropyLoss() #损失函数的设置 +optimizer = optim.Adam(model.parameters(), lr=0.001) #优化器的设置 + +input_batch, target_batch = make_batch(corpus) #训练集和标签值 +input_batch = Variable(torch.LongTensor(input_batch)) +target_batch = Variable(torch.LongTensor(target_batch)) + +for epoch in range(10000): #训练过程 + optimizer.zero_grad() + + output = model(input_batch) # input: 4 x 2 + + loss = criterion(output, target_batch) + + if (epoch + 1) % 1000 == 0: + print('epoch:', '%04d' % (epoch + 1), 'cost = {:.6f}'.format(loss.item())) + + loss.backward() + optimizer.step() + +predict = model(input_batch).data.max(1, keepdim=True)[1]#模型预测过程 + +print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in predict.squeeze()]) +``` + + epoch: 1000 cost = 1.010682 + epoch: 2000 cost = 0.695155 + epoch: 3000 cost = 0.597085 + epoch: 4000 cost = 0.531892 + epoch: 5000 cost = 0.376044 + epoch: 6000 cost = 0.118038 + epoch: 7000 cost = 0.077081 + epoch: 8000 cost = 0.053636 + epoch: 9000 cost = 0.038089 + epoch: 10000 cost = 0.027224 + [['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'do']] -> ['studying', 'datawhale', 'playing', 'nlp'] + + +# Word2Vec模型:主要采用Skip-gram和Cbow两种模式 + 前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。 + skip-gram模型(跳字模型)是用中心词去预测周围词 + cbow模型(连续词袋模型)是用周围词预测中心词 + + +```python +import torch.nn as nn +import torch.optim as optim +from torch.autograd import variable +import numpy as np +import torch +import matplotlib.pyplot as plt +from tqdm import tqdm + +dtype = torch.FloatTensor +#我们使用的语料库 +sentences = ['i like dog','i like cat','i like animal','dog is animal','cat is animal', + 'dog like meat','cat like meat','cat like fish','dog like meat','i like apple', + 'i hate apple','i like movie','i like read','dog like bark','dog like cat'] + + + +word_sequence = ' '.join(sentences).split() #将语料库的每一句话的每一个词转化为列表 +#print(word_sequence) + +word_list = list(set(word_sequence)) #构建我们的词表 +#print(word_list) + +#word_voc = list(set(word_sequence)) + +#接下来对此表中的每一个词编号 这就用到了我们之前提到的one-hot编码 + +#词典 词对应着编号 +word_dict = {w:i for i,w in enumerate(word_list)} +#print(word_dict) +#编号对应着词 +index_dict = {i:w for w,i in enumerate(word_list)} +#print(index_dict) + + +batch_size = 2 +voc_size = len(word_list) + +skip_grams = [] +for i in range(1,len(word_sequence)-1,3): + target = word_dict[word_sequence[i]] #当前词对应的id + context = [word_dict[word_sequence[i-1]],word_dict[word_sequence[i+1]]] #两个上下文词对应的id + + for w in context: + skip_grams.append([target,w]) + +embedding_size = 10 + + +class Word2Vec(nn.Module): + def __init__(self): + super(Word2Vec,self).__init__() + self.W1 = nn.Parameter(torch.rand(len(word_dict),embedding_size)).type(dtype) + #将词的one-hot编码对应到词向量中 + self.W2 = nn.Parameter(torch.rand(embedding_size,voc_size)).type(dtype) + #将词向量 转化为 输出 + def forward(self,x): + hidden_layer = torch.matmul(x,self.W1) + output_layer = torch.matmul(hidden_layer,self.W2) + return output_layer + + +model = Word2Vec() +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(),lr=1e-5) + +#print(len(skip_grams)) +#训练函数 + +def random_batch(data,size): + random_inputs = [] + random_labels = [] + random_index = np.random.choice(range(len(data)),size,replace=False) + + for i in random_index: + random_inputs.append(np.eye(voc_size)[data[i][0]]) #从一个单位矩阵生成one-hot表示 + random_labels.append(data[i][1]) + + return random_inputs,random_labels + +for epoch in tqdm(range(100000)): + input_batch,target_batch = random_batch(skip_grams,batch_size) # X -> y + input_batch = torch.Tensor(input_batch) + target_batch = torch.LongTensor(target_batch) + + optimizer.zero_grad() + + output = model(input_batch) + + loss = criterion(output,target_batch) + if((epoch+1)%10000==0): + print("epoch:","%04d" %(epoch+1),'cost =' ,'{:.6f}'.format(loss)) + + loss.backward() + optimizer.step() + +for i , label in enumerate(word_list): + W1,_ = model.parameters() + x,y = float(W1[i][0]),float(W1[i][1]) + plt.scatter(x,y) + plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom') +plt.show() +``` + + 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s] + + epoch: 10000 cost = 1.955088 + + + 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s] + + epoch: 20000 cost = 1.673096 + + + 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s] + + epoch: 30000 cost = 2.247422 + + + 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s] + + epoch: 40000 cost = 2.289902 + + + 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s] + + epoch: 50000 cost = 2.396217 + + + 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s] + + epoch: 60000 cost = 1.539688 + + + 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s] + + epoch: 70000 cost = 1.638879 + + + 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s] + + epoch: 80000 cost = 2.279797 + + + 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s] + + epoch: 90000 cost = 1.992100 + + + 100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s] + + + epoch: 100000 cost = 1.307715 + + + + +![png](output_16_20.png) + + + + 打印 + + + +
+ + + +```python + +``` diff --git "a/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/output_16_20.png" "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/output_16_20.png" new file mode 100644 index 0000000000000000000000000000000000000000..56fa3da4e714085b37e5b0637ef20ecc8ad6edcd GIT binary patch literal 10308 zcmcI~cT`hdx_3eqK@bo@P*4Psj)Y<;0W=_pf>NXxMMAHk_lQarL~5jpfFez#NfCmA zQbIs_6Ob0^9YVi5Gw(NZ@65Mmt^3Dit>9UboO81G^ZeR#BJOG3I(3rsBm@FErJ;`0 zfk2=E;CSx@4fvhydr%AhfqSSLdg!{KJiH&eTSM+X^l)`_@o=PBb?gr1fBfTJ)g9t zU;aLNeP52(c4U8ReHMJroDYp^5yhP53o&&pw9F4-kF-_2uc_O7)2ZSSTdKXH6H_^N z{v)J{^8J{TPp+N~K5<6s1l5TXZC~yevot=X`o{nG@F?Zd-5V5iSo|FnqJ-EQKjcs{ zv><~55yIzCiO#&wl@ zEe*S*qea_{4cf<>_0;a%VVIkn!-(Dda1jQJ=W_mxDK@Kpm6W8`o+OG@3PzMP9LeeC z-TB=V!FsS7rZinKV1D`X<(%@#&&$YkV|@=N6xU*czrr8w@*k%s%%fWo$f3i`>jKE9RFI{VesQa%B7JioB7`S z&H28l#!y<=!>?45l9Hb%BZWB>{iKdbKa>n>y+m0Bw5E^tSC}^2ne$7Z_%T`BxVg|@ zG}oO=5BXK&c^TaP-Bg*iCRXj2>FIaP&ABOQ7_LwGf}fn++-r`5AG0f`e#GtYY&v;y(W&a!_osBMe}*fB<-M25XStIum-zYfbvduamV@05O2Z2E zjn)keMueCAa zDs~oLLAk@K)1zldED2&x#tk}4yzb5?1L$tYGT7ca9FF53o*#&rj)N{y=*6jbybG;l zNzmd+=`)h!tp_WZC0*{INnIDtwEOt>4}2bWsByj6-C_6eFnO14D zWJ1tlY=$Qes0N_{OwsGjpLuz8+z@mD#L;kcp*N1Qoc%+1QW*u@{NM^9bWWm?2!64qBRAq`HQN9O22ACShQg zj7+MWt-%KzM?c4b0)+XlKmHzl?i;zI%I#QpcX#WKkM>m$%gV}j2O5r=U4r*VvMa|A zm%Yy`_`Ec*2uut=aag0?k(X&nXH;Zdk?p-Q201t&DiPzg5f4FEfm4SP(&ft@x3b8% zjG6WZlU9H8XB=<#Xve3gcW8=#d1_?D9>A7w4%%pFKq+uv&C7nmF)12Z&Iy*a83<&G|Di0ICR1mnojZe82YH_rlT<$cIWN|oe z5CE3OV|f&tlEP8ry%I!lV&IaA9vwBMVdBPreMX;e(NNDL#GcRXwI}k-GU4URiSC{r zO0(xru}ofpIP^ls6z(UbY4JPU~G93$Q80=x;3Fd-G+34>Q*<5Pz-NLH0j{JOoL@vM!rl+TslPA1GhZ%Hmn(FJV z-1XQ~Qe{P`%5V7ZiKX=sSW$_xQ>e-zyxQ&~Q@GO2FveU^G#-nCae}&ekkf2z%~P>T zTBfGwBO)RKHacV>VA~|zCXqO_YHJv`UwD1ses_JKKR_umSl=0`bn=~y5Fpe9GSk%w z(bi*CDa0YNkP9KqsQx#4RkR`EX=pd);0YJhgLfaS+IO)sPMi({#r%qj=^F7@-2QZ3 zis4qg{pb-dmyBbU%XmGRDy$JzTIq5k7WHY<(~qAsocl^IUWd!On@YyvFvbf)-ut{P zKJ|m&8aM}kr;W^{$wTop&lBdSLeyrfU+bT7<>)fILh-meM?=+rz1DjLwZFa8o39&2 z<|l~9LjMaq9xs_G@3U$Nk@H@@^CRX$L`+Pc^RJq6yPoI2YQ2+d4_|*UC^5%~sWjOH zTnig#88mbBW$JpXmr2Czgaj&Iw)~7P#@A}34B1*r;h2I}DR~kEnng`2qno4H%iL#f zb4uANA0HjUy%sqt90tT{Y9521YVW`gCRJh(W@cv9SZ#6kGiMUOg@i;!MF9|yqQ$pY z+iULm%+YT$KqNynF=Dt4RR8C~7jRbMXTM&6sHJg>iRIaZtb%@9mPeB_vDrj38pmIwRFtS0WAa= zTG`Q21Tde#wQHqnP8LXlMPf;k9jYS4Ntxv%yCLQ1a@vS#VP_OhmQZ~LR?A^X+XN>P z{B$adlZ?<8%z|HDj*^5Y2Bx*-dDn+)IQSeQ-s z2Yq4Va$6O^tDgbHscUJ??yocuBZZ4Chs*7x9QvcB9r|5YzU*{=e0a0CdX^J$uzFL* zY52^aT|%Sy%uG6pax_!WOle!cK# zFZ6NeFGBv}6eO(;p|w4_blxmj0+-l2T;YIK^p{x}D6v@{tu8bwqX~(-d}|$`U*Zcc z+3sqz%@L3i6cYBmS9NrCjjG-G0sUh{%mQ!%#NYD7vvR~@4}#h6H2?1-N%Lu%TZl1Y zTJ0EoUbM<|VINFicH7ZKimo%PvE zlNM}WF8NE0pVuJ7QYbYg?&J@TA8l7P{rC}|kbr#f;OxXM-%?5Np@ZMnyy`oxbi*%x zTjxgF+uA0ROq|~Ul=WB{OV?I9+>A!V-CfA|=~cO%Rk|kXQ^}emwG}W^&Zsd~gr|?!)M#B@|$Ej!_w^(6E#uWWe12MbP2i&q5;@r?kDEM_%9JP}$$^@~`_L zk$n1#Ol~y<9Y)SNK*>eJmQHXog|F={Y+zrwG(I_glJ+ZPBfw(SB4hhB3lC4{>hEu4Iaf96P9M0p3AF@# zrm)^}x#|1&S8;KTfS2fSpt;WhO{nO<-2}fCf8{QPQN3QASjO4OD|cTv)vnaj-`mU| z2oFE;oQa1K*BN|-cCO!@%|5!U77bDGe{w74{3Ut$+NSNZ;^ugEi-7y(0n;taU>WRs z-`~GKGFRqi-)khtqeLyRp;?)Z-zoobBqAZLAOC|VI1RiGqPM#&CI5{ zy&*O%nZ>FE9h{!Zh63~ek zvSSl~9tHgf$A{y`&{gn}t{jbc*wNF5qrC;+qg`TD5=>I^!>&F_Bfec|fO`^gM6%M# zP|yHHLheMMGi;i|nF}o%5M8+%Y~&jS=XU~16&Qy+Knfy8pK@&L&t`!-Aggcy#(7%l ze&1yT?!HF)`N@@*mU{f@!u0m)HGIjiPu%3|727helz zif3^1)21J`;wxFW>rG1&V)GQVm#d5>&R374)_x`scio%2m6K=XOO-;yWK?cf68$-z zAKGbXXaIe2&3An|yAqVit5>f`6QN9hhROuSgZ4#v0ymjKY4ITt84tfbe>Q4Y_(l!G z^So_8*R;~nsM>!=7}c3Gfox^cVoSSDh#ab~9xtLw9vn0T=ZO@1cdsER5SXyRYP3M` zZo7ryXV*)R;Juzq!#ba9z^YunbP3aNe7J%UF(Ihf%&9+{<`&xTzBN;!*)afUR;+II zhK`O78Ddzk+#%bsSj1j9kainj_cof@ExKyGT_9R%*Ht<*6e&$B(q2$l2<7~I)}hz~ zw=UQ8vF*t_RqM6%2+%YiAKygU)DNYjYcBP>s4UI+lidk1dg_@YYB~j4t(up6Kj_;I&4!D5^=bx>$Nxt#qEp-cUC%$P?+Za{aS&U4%I=M z9GucFP?VJY*pr1@v2;ac9@A+6d4B-B<&bgY04?@#Z*wl1i=hO zEzogjvRWqghSF!d+qaNH*5Jdc4}*;#K19@d&7=M_Y5|L!=)Oogx?Q2UdT%01J-4@j z%PfC;BMysg<#MU}UA7bvMqz;z%I#odVr9O+egnl;wg2hdaTA4{=X}!z@6pD|=BT+}wX#klmDXGHediVZ68x8|!^wXxYA-lvPSVr=s`K3# zeR=8Ts8y`uthxV^G%*c59=Mb4Fj866-rkPB70c}j1U~^z1iLk>bxaU9tLZ((DRtGd zB>xI!w;3F3JkulDA;RJVC0+YSYiNa9ffMVm_P-O9ILGlPZGRYV-&OPyIc;iO<$RY& zNQeFzEWHdmCFNYYy!RDQh@dA{$NX2Wxn&dYWB`a?+ZwPy1F(`r;W>a)508%S74R&; ziWC0QY1EdeGR!3$d-L8XBi@q`cxoa|{!gfzDxZ39(SByTGZKkBBYF#oya&|DLk}#06hN$4+sqA=6+?H@Xt<9pT2P6E)q#q@@1P7Kp{oe z{kLkt#%hTB!FA(wKMXv5Ut!N%4?hv4HGS9(KFH=%lz~b8iM4hPS4hDuQJSvKQP{o9J z9s~){bG0>kyT)>;epYoi!l_j;C7Ow}eRL^eD#Q<{u9bDYiA+LuRcGbWr{Y|gT{oY3 zz?v%K`Rlm1v52d zbngprq1C^ibA8?(b|C8;p!jDFcgBgNwW(2$p1TJig`l$h)AeR?>`N-x#%o@ShHA0g zr*H)&^=+BQ2R7Mn-&U`;3Fn(s6_3N-*OcUz(bzhmiZEYrPCm1XQK?J{E?Qvw{2X!5V0MH=W#1<8ax*smH%TX0R|{)AnjR;zib=V zcrDEj40qBVhDL~*6PY+QOX<}}w6m`DkK0;WsDWZL z25yBE(0-M+L5oXO<9T4unwpz`R=JF49|5CaHG#vBEhO0#>j&oZ>fF5|GHiTi``t{Y zLNf?m4$a$Vk}b4=VIlJZ`9&^LzzsnDe+=CB!ohk@v$NjqV80Lm2kPAw&a z#tU*ypf7y)m&&-a8Ny$F<_-eiu;4+_WGE9N3y3A)i6H!)r(&X`b3huR%of=)q#&7P z76A_k*{iq=3b?jRp}@~9_;4!#+@7i|DmF555`c^i*rJ@gJQgl4O+anxckWPKl#_d% zkZ{Itdr<>KacgU9{9ekTD(X~6Agh@B`Ic7EZ_~2z2{owj{ZVa}xGc9loMTpJH*mNQ zd-9IYn}bCW0j00_c07zH$VbhD@PyiQV|W1{Tc}#pVr=U1%Tua)?Yn(Z)u$W0t31Bb z%#|+VG3c0_96pc^MsBaG=Abx+D{{vv#|+%hidY_#1l-1|51s_>4=NpLgk+rz{MGXS znG;!gd^H06C&9MFb1t{sY`xUNfv(nnm;iT8QCb=99GeobGpdq_t%#US(iA)4VZ_Ul zeR4ZFu&QKzn3A85&@Jdpz&fYO&}E*DOuexYPQ=HH)=@i38kM zFR_z?$0{-pzP{?w+kFdrbx9y%*t6qprw~ad<#2 zj9c!`DfC}?y~oV2G-j&w)u-^TYe#NGILV}C>31fA>z2)2czD7YcMMZTmjNz{e5fl2 zqBI%m`f&9WK+NG!>Hfpvs>Vd>kTPt<9wm==95veWlr7Umc@+T9#T*w|;69OR#J5V} zO4$YFc!;{tHn(G7d0+p>8RodQjb`qnJ{%))RE#)EXDK6*6CK1vfer~paOfhDLapqW zzgzm)s-kdg1(h>}PG`k67?%@j4waOVm^cGq9Z40yN_wDXGAIA`SbX(8dbzh+X7-=q z;?J+NGZEn@NVE1>|BwgGN#c~Z$LkfLDNZOz2PQMc2!nRcVEUfSc~M$A5jZ~y_vsep zWL}oa0A`XyNF}PDbC7VpOA2{Gtm2uCP! zF;yH>m4;3QXolZ3F{iXfq1^lQ4cq|H|IF7FSpGc42wck@(8YEmm7)a(C9?ou7%s>q z{3Wx1D*;T?)YJsxZ?p0uo}<+;a;EcFqyzC<9&N>8DChc%xPZ5M2r_P{UoP+OW&v4c zH-aeD*(7`64M-7s+1Vmmu+ISmC3V1_-oZ*f11e|ZJC_UDbC5N_+BmQv<@SA%U``+j zGV>%6qpvM-ms?v}EL&b)ssjKM%_$uTB0y$V)^C8zO&^WCy%l=}EDoT@PxgI6%K-1d zhgh-RH8`SjXW>`yW>qC%2 zDs1)Yk}on;W&;Z#t+I@Cw6|Xt6@39DJ6pvCrJ&JZ@jHCcf~&cAJQw;!{mwHp-!BKh zR8&-;753@7Z>><9i$n2XJn}Lzu^C7!9}sHY-V_9a!>%U}AeElOA+P%tGtFBBO)@=6 zJDi1p<$4RENh#-14%k#)s`I35!|o|CwIG7LN@T6>-v$7UD_c=dFH_g2B?GJiX(Lb! z)}x=@Kydli_7~VU-@~ne&J+UN>E9**hx(Q3hA`soU^O8hAmcx8Zf+_kUqWm)7Bl%c z|3aU??iZrTY5sMPR0@*8t6&iDl;NL)fMyrO;1*QSPVl_ECqu|Tl1lb9%DZ#n7ZPm?*o|UAYDo#^ z06sreOx99+Xi=U?G3u6ASDQ;{D$eDEu+yCem-x>;{%;_VN%cPh%OghE$e+X8u(Ioa zpB)&?;AG3xC#WZAvM43K;^h+Nw_UO_(PC$PI8$``$YTWefYPvF>5t#wRqb$Zij-3f zf@6*FE0zhM2b~evJ5U`R??yINC#tBt*iFW!oP;~S;KF8#r2aWwKwQ;E z=FHOdND?!v`S_^$WM=eIRQu%>s4<&KJ;R_uO+2=l zq68v)oyfRrY04bOX}vMqp|F961?;NZ+QAg@fj?)C)aI0hpm{furm0vp6CL=~5eGa*!WppqO>tz*vG6zml7 z9U|ZG-34;R;p%s%(!Lc=i;w`Jx1?ONA{}gW`T_@V9Z1NE7c21_tLbR^bhH7;&|C?M zQ%;(mgj_$KablXoY}&%);q6uOTt!+#M~B7b%huU2Mqg9~6A$7JkQAaOegri3nw8r9 zabS1Ah;5-~ZA-(1nOSKLxR6+-qw63)xhy6YGk!4pH#Bn`lmG@(SE`TFb?SxeI!KuI zHnWQ02D5b1Dwi%@YRrA-uzd^p^;@Mrtsy*Ez&9+G)lL6?&3$MK4=nQ_(o&I4ahnpkIh`6C?t{baNQPoO}-O_X$@sya3*~HeoNb~e4xy8 z;3N4?OkY5l1AqAj$IeVgai6`nNN<0aT-sEajGM(-57J-7nHFeKDTbN;F1zN@^#FV# z4IyFDSLQ)bBad_WiWl7?v?=SQ{`dXzul)HxNYnJ(LI?he?|{ocW^e!AKVFnw{rl3K zKBXU~BE!{ZNROC~?JdL5&dKOyDhVzQoe!O-E%P==47T>Dp9l^UK%g!EjX*yILjK!0 zj@&RvoGix0y1aVh&y<~qb0k^}bYB;cu|n(uyu93?J{6S<96Vum;+xwP|%*^>&LgPLCa05#eC`e^KiH^T1{4Wc- zm}fZckPEmS-1`qPV6d1t6`z^{GTTYItx2c+W^VZjs6{kCFE8;{W;tkj){_CuaT0&q zqW{~&`?&z4=<$ok*dec#9hCK^oi$3cV^I?};)t>;kjx7YPs{wU{5zHs;YFyRMr?~E zKY$$c!w0z4skkM|81>H^fJM&#fKj4K`5mZ0$zfa*igs>5JE!lsiLLT?eTpP+6u-Xo zs<;hS9f~gAL=mJ46(Pg$uf=K`v3y-3cBepsgVEHR8`wCMT56 zueC8eM%~AqpBNxiU!K_lh8CI!A+t(bH1uUs(u1LQ0J5KLQmmu(O%L2+S6Iv5!So@6v|CSZ{)O#**WrD ze}L+yS@U0F-$A0+8g~}A2r`^QlIo3(0ESUPH8ERX-bjLn^XMFw1 z#rFk&A8wBAIOur}=rz{7#tzkvjkorp`dhU<@50JLQI+{=TPRd~ohJZ9(ru7H(MO11 zWErF@ht5N|LKt)RHYr~F@{B_YPi7BilRB>@8d_>f^y(Lmvshko1@}U|Njt{$HM2Ai z$BtvFYi9>{kh?d2LWJ&lO~^Q9@D{=udY)#UO2v#dM9GK8f;wRFU1n*$vcgAjT_q@j zib1{5@%gjve{m;1v_#X_=w|Z@P&21(R+fr)q+f-) zYLk~;6q6dFr_`(ayQa?1?;Qm~sw_rDXCP4-1~WmS96B5yzwD(B!Ae7DCU=<>IOhGh zZJ8D+5wzdsh$lm{icjG>K<`Ss!jDE&6jIw}P3hj8Av-Fpe=EvMGy@l0ZY(S$^Z1;T zP48@){{Y;ZP2!ACO-)qxBhSOXIWE1eWZfKI&!TVWj+fUbcyM*|Vv&)3`t*Nj7$|5U zQyq(K7vC?Gb78d8(KaJdn~9OV(vTL27nsoht(^1ab5@m)HY!fuZ(KnVKmZZQ+!D_$aU&|9eZ18Ha<2-}ySgjk z7BR9D5v~yDE8C2KbxDN(wL~p}!Bm!W7p^t6V58S-S+g zJbm}28(~VLA8U?J39kixIgtY@Z}G6&g`3J#y4in~gG8O-yI~*PSvS>u??`1tB|*Wq zwnlP<5vv$LG{s4+UN2s4Xz;z z>$NBQJ%v|=VZykLpkHShR)A2@gzXD_WWGve?nX%TjZ9ZsSoiuuM{4ms+I#QCR=8ki zS%VFX5oqG}?AtdmR$Z`X!b_Fu`VeYrs^lExI`CsnXc}WQY=rUv`rfRt_?Zu(cL*8~ zyHoP=q)Wr=U=RxPsXy&j0XPQ-56t}M=pf>l>0&{Ym+8W~HF%;AqM@pVEWTm+ Date: Wed, 14 Dec 2022 23:37:37 +0800 Subject: [PATCH 4/6] =?UTF-8?q?Delete=20=E5=91=A8=E8=BE=89=E6=B1=A0NLP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../\345\221\250\350\276\211\346\261\240NLP" | 1 - 1 file changed, 1 deletion(-) delete mode 100644 "docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" diff --git "a/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" "b/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" deleted file mode 100644 index 8b1378917..000000000 --- "a/docs/\347\254\254\345\215\201\347\253\240/\345\221\250\350\276\211\346\261\240NLP" +++ /dev/null @@ -1 +0,0 @@ - From d6f7b3f1adca8eff31e83ccc2c7675dbf8141129 Mon Sep 17 00:00:00 2001 From: Zhikang Niu <73390819+NoFish-528@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:41:57 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=E5=88=A0=E9=99=A4=E7=A9=BA=E6=A0=BC?= =?UTF-8?q?=E7=AD=89=E7=BC=96=E8=BE=91=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../zhcNLP/NLP\345\237\272\347\241\200.md" | 85 +++++-------------- 1 file changed, 20 insertions(+), 65 deletions(-) diff --git "a/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" index 51843e99d..a2ccd861c 100644 --- "a/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" +++ "b/docs/\347\254\254\345\215\201\347\253\240/zhcNLP/NLP\345\237\272\347\241\200.md" @@ -1,23 +1,23 @@ # 词嵌入(概念部分) -###   在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
- 计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
- 我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
- 例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
+### 在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
+计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
+我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
+例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
一个小朋友刚学会了四个字和词-->[我] [特别] [喜欢] [学习]
- 我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
- 对于中文 我们先进行分词 我 特别 喜欢 学习
- 那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
- 现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

- 我们可以思考几个问题:
- 1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
- 2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
- 因此我们认为独热编码不是一个很好的词嵌入方法。
+我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
+对于中文 我们先进行分词 我 特别 喜欢 学习
+那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
+现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

+我们可以思考几个问题:
+1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
+2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
+因此我们认为独热编码不是一个很好的词嵌入方法。
- 我们再来介绍一下 **稠密表示**
+我们再来介绍一下 **稠密表示**
- 稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97] +稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97] # Bag of Words词袋表示 @@ -72,7 +72,7 @@ number_dict - 根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用) +根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用) ```python @@ -104,9 +104,9 @@ bow # N-gram:基于统计的语言模型 - N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。 +N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。 - N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 +N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 **N-gram 模型的优点:** @@ -255,9 +255,9 @@ print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in # Word2Vec模型:主要采用Skip-gram和Cbow两种模式 - 前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。 - skip-gram模型(跳字模型)是用中心词去预测周围词 - cbow模型(连续词袋模型)是用周围词预测中心词 +前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。 +skip-gram模型(跳字模型)是用中心词去预测周围词 +cbow模型(连续词袋模型)是用周围词预测中心词 ```python @@ -363,72 +363,27 @@ for i , label in enumerate(word_list): plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom') plt.show() ``` - 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s] - epoch: 10000 cost = 1.955088 - - 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s] - epoch: 20000 cost = 1.673096 - - 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s] - epoch: 30000 cost = 2.247422 - - 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s] - epoch: 40000 cost = 2.289902 - - 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s] - epoch: 50000 cost = 2.396217 - - 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s] - epoch: 60000 cost = 1.539688 - - 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s] - epoch: 70000 cost = 1.638879 - - 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s] - epoch: 80000 cost = 2.279797 - - 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s] - epoch: 90000 cost = 1.992100 - - 100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s] - - epoch: 100000 cost = 1.307715 - - - ![png](output_16_20.png) - - - - 打印 - - -
- - -```python - -``` From 018a3b80b93b2ea90f4b6d6fcd0f37ad6a489a4b Mon Sep 17 00:00:00 2001 From: Zhikang Niu <73390819+NoFish-528@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:47:59 +0800 Subject: [PATCH 6/6] =?UTF-8?q?Delete=20NLP=E5=9F=BA=E7=A1=80.ipynb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../NLP\345\237\272\347\241\200.ipynb" | 687 ------------------ 1 file changed, 687 deletions(-) delete mode 100644 "docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" diff --git "a/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" "b/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" deleted file mode 100644 index ba5bcc1f9..000000000 --- "a/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" +++ /dev/null @@ -1,687 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "119ec186", - "metadata": {}, - "source": [ - "# 词嵌入(概念部分)" - ] - }, - { - "cell_type": "markdown", - "id": "f8e5639e", - "metadata": {}, - "source": [ - "###   在了解什么是词嵌入之前,我们可以思考一下计算机如何识别人类的输入?
\n", - " 计算机通过将输入信息解析为0和1这般的二进制编码,从而将人类语言转化为机器语言,进行理解。
\n", - " 我们先引入一个概念**one-hot编码**,也称为**独热编码**,在给定维度的情况下,一行向量有且仅有一个值为1,例如维度为5的向量[0,0,0,0,1]
\n", - " 例如,我们在幼儿园或小学学习汉语的时候,首先先识字和词,字和词就会保存在我们的大脑中的某处。
\n", - "\n", - "
一个小朋友刚学会了四个字和词-->[我] [特别] [喜欢] [学习]
\n", - " 我们的计算机就可以为小朋友开辟一个词向量维度为4的独热编码
\n", - " 对于中文 我们先进行分词 我 特别 喜欢 学习
\n", - " 那么我们就可以令 我->[1 0 0 0] 特别 ->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1]
\n", - " 现在给出一句话 我喜欢学习,那么计算机给出的词向量->[1 0 1 1]

\n", - " 我们可以思考几个问题:
\n", - " 1.如果小朋友词汇量越学越多,学到了成千上万个词之后,我们使用上述方法构建的词向量就会有非常大的维度,并且是一个稀疏向量。
\n", - " 2.在中文中 诸如 能 会 可以 这样同义词,我们如果使用独热编码,它们是正交的,缺乏词之间的相似性,很难把他们联系到一起。
\n", - " 因此我们认为独热编码不是一个很好的词嵌入方法。
\n", - "\n", - " 我们再来介绍一下 **稠密表示**
\n", - "\n", - " 稠密表示的格式如one-hot编码一致,但数值却不同,如 [0.45,0.65,0.14,1.15,0.97]" - ] - }, - { - "cell_type": "markdown", - "id": "4db86da3", - "metadata": {}, - "source": [ - "# Bag of Words词袋表示" - ] - }, - { - "cell_type": "markdown", - "id": "44dc9252", - "metadata": {}, - "source": [ - "  词袋表示顾名思义,我们往一个袋子中装入我们的词汇,构成一个词袋,当我们想表达的时候,我们将其取出,构建词袋的方法可以有如下形式。" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "823f8f2d", - "metadata": {}, - "outputs": [], - "source": [ - "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]#我们的语料库\n", - "word_list = ' '.join(corpus).split()\n", - "word_list = list(sorted(set(word_list)))\n", - "word_dict = {w: i for i, w in enumerate(word_list)}\n", - "number_dict = {i: w for i, w in enumerate(word_list)}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "8eaeb37d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'do': 0,\n", - " 'drinking': 1,\n", - " 'hate': 2,\n", - " 'i': 3,\n", - " 'like': 4,\n", - " 'love': 5,\n", - " 'nlp': 6,\n", - " 'playing': 7,\n", - " 'reading': 8}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "word_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2bf380c8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{0: 'do',\n", - " 1: 'drinking',\n", - " 2: 'hate',\n", - " 3: 'i',\n", - " 4: 'like',\n", - " 5: 'love',\n", - " 6: 'nlp',\n", - " 7: 'playing',\n", - " 8: 'reading'}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "number_dict" - ] - }, - { - "cell_type": "markdown", - "id": "90e0ef43", - "metadata": {}, - "source": [ - " 根据如上形式,我们可以构建一个维度为9的one-hot编码,如下(除了可以使用np.eye构建,也可以通过sklearn的库调用)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "9821ed2a", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "voc_size = len(word_dict)\n", - "bow = []\n", - "for i,name in enumerate(word_dict):\n", - " bow.append(np.eye(voc_size)[word_dict[name]])" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "03f1f12f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", - " array([0., 1., 0., 0., 0., 0., 0., 0., 0.]),\n", - " array([0., 0., 1., 0., 0., 0., 0., 0., 0.]),\n", - " array([0., 0., 0., 1., 0., 0., 0., 0., 0.]),\n", - " array([0., 0., 0., 0., 1., 0., 0., 0., 0.]),\n", - " array([0., 0., 0., 0., 0., 1., 0., 0., 0.]),\n", - " array([0., 0., 0., 0., 0., 0., 1., 0., 0.]),\n", - " array([0., 0., 0., 0., 0., 0., 0., 1., 0.]),\n", - " array([0., 0., 0., 0., 0., 0., 0., 0., 1.])]" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "bow" - ] - }, - { - "cell_type": "markdown", - "id": "086a5fd2", - "metadata": {}, - "source": [ - "# N-gram:基于统计的语言模型\n", - " N-gram 模型是一种自然语言处理模型,它利用了语言中词语之间的相关性来预测下一个出现的词语。N-gram 模型通过对一段文本中连续出现的 n 个词语进行建模,来预测文本中接下来出现的词语。比如,如果一个文本中包含连续出现的词语“the cat sat on”,那么 N-gram 模型可能会预测接下来的词语是“the mat”或“a hat”。\n", - "\n", - " N-gram 模型的精确性取决于用于训练模型的文本的质量和数量。如果用于训练模型的文本包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。此外,如果用于训练模型的文本量较少,那么模型也可能无法充分捕捉到语言中的复杂性。 \n", - "\n", - "**N-gram 模型的优点:**\n", - "\n", - "简单易用,N-gram 模型的概念非常简单,实现起来也很容易。 \n", - "能够捕捉到语言中的相关性,N-gram 模型通过考虑连续出现的 n 个词语来预测下一个词语,因此它能够捕捉到语言中词语之间的相关性。 \n", - "可以使用已有的语料库进行训练,N-gram 模型可以使用已有的大量语料库进行训练,例如 Google 的 N-gram 数据库,这样可以大大提高模型的准确性。 \n", - "\n", - "**N-gram 模型的缺点:**\n", - "\n", - "对于短文本数据集不适用,N-gram 模型需要大量的文本数据进行训练,因此对于短文本数据集可能无法达到较高的准确性。 \n", - "容易受到噪声和语言纠错的影响,N-gram 模型是基于语料库进行训练的,如果语料库中包含大量的语言纠错和拼写错误,那么模型的预测结果也可能不准确。 \n", - "无法捕捉到语言中的非线性关系,N-gram 模型假设语言中的关系是线性的,但事实上语言中可能存在复杂的非线性关系,N-gram 模型无法捕捉到这些关系。" - ] - }, - { - "cell_type": "markdown", - "id": "1f5ad65b", - "metadata": {}, - "source": [ - "# NNLM:前馈神经网络语言模型\n", - " 下面通过前馈神经网络模型来**展示滑动**窗口的使用" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "7bddfa77", - "metadata": {}, - "outputs": [], - "source": [ - "#导入必要的库\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from tqdm import tqdm\n", - "from torch.autograd import Variable\n", - "dtype = torch.FloatTensor" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "29f23588", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['i',\n", - " 'like',\n", - " 'reading',\n", - " 'i',\n", - " 'love',\n", - " 'drinking',\n", - " 'i',\n", - " 'hate',\n", - " 'playing',\n", - " 'i',\n", - " 'do',\n", - " 'nlp']" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]\n", - "\n", - "word_list = ' '.join(corpus).split()\n", - "word_list" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "12b58886", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 1000 cost = 1.010682\n", - "epoch: 2000 cost = 0.695155\n", - "epoch: 3000 cost = 0.597085\n", - "epoch: 4000 cost = 0.531892\n", - "epoch: 5000 cost = 0.376044\n", - "epoch: 6000 cost = 0.118038\n", - "epoch: 7000 cost = 0.077081\n", - "epoch: 8000 cost = 0.053636\n", - "epoch: 9000 cost = 0.038089\n", - "epoch: 10000 cost = 0.027224\n", - "[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'do']] -> ['studying', 'datawhale', 'playing', 'nlp']\n" - ] - } - ], - "source": [ - "#构建我们需要的语料库\n", - "corpus = [\"i like studying\", \"i love datawhale\", \"i hate playing\", \"i do nlp\"]\n", - "\n", - "word_list = ' '.join(corpus).split() #将语料库转化为一个个单词 ,如['i', 'like', 'reading', 'i', ...,'nlp']\n", - "word_list = list(sorted(set(word_list))) #用set去重后转化为链表\n", - "# print(word_list)\n", - "\n", - "word_dict = {w: i for i, w in enumerate(word_list)} #将词表转化为字典 这边是词对应到index\n", - "number_dict = {i: w for i, w in enumerate(word_list)}#这边是index对应到词\n", - "# print(word_dict)\n", - "# print(number_dict)\n", - "\n", - "n_class = len(word_dict) #计算出我们词表的大小,用于后面词向量的构建\n", - "\n", - "m = 2 #词嵌入维度\n", - "n_step = 2 #滑动窗口的大小\n", - "n_hidden = 2 #隐藏层的维度为2\n", - "\n", - "\n", - "def make_batch(sentence): #由于语料库较小,我们象征性将训练集按照批次处理 \n", - " input_batch = []\n", - " target_batch = []\n", - "\n", - " for sen in sentence:\n", - " word = sen.split()\n", - " input = [word_dict[n] for n in word[:-1]]\n", - " target = word_dict[word[-1]]\n", - "\n", - " input_batch.append(input)\n", - " target_batch.append(target)\n", - "\n", - " return input_batch, target_batch\n", - "\n", - "\n", - "class NNLM(nn.Module): #搭建一个NNLM语言模型\n", - " def __init__(self):\n", - " super(NNLM, self).__init__()\n", - " self.embed = nn.Embedding(n_class, m)\n", - " self.W = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n", - " self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n", - "\n", - " self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n", - " self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n", - "\n", - " def forward(self, x):\n", - " x = self.embed(x) # 4 x 2 x 2\n", - " x = x.view(-1, n_step * m)\n", - " tanh = torch.tanh(self.d + torch.mm(x, self.W)) # 4 x 2\n", - " output = self.b + torch.mm(tanh, self.U)\n", - " return output\n", - "\n", - "model = NNLM()\n", - "\n", - "criterion = nn.CrossEntropyLoss() #损失函数的设置\n", - "optimizer = optim.Adam(model.parameters(), lr=0.001) #优化器的设置\n", - "\n", - "input_batch, target_batch = make_batch(corpus) #训练集和标签值\n", - "input_batch = Variable(torch.LongTensor(input_batch))\n", - "target_batch = Variable(torch.LongTensor(target_batch))\n", - "\n", - "for epoch in range(10000): #训练过程\n", - " optimizer.zero_grad()\n", - "\n", - " output = model(input_batch) # input: 4 x 2\n", - "\n", - " loss = criterion(output, target_batch)\n", - "\n", - " if (epoch + 1) % 1000 == 0:\n", - " print('epoch:', '%04d' % (epoch + 1), 'cost = {:.6f}'.format(loss.item()))\n", - "\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - "predict = model(input_batch).data.max(1, keepdim=True)[1]#模型预测过程\n", - "\n", - "print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in predict.squeeze()])" - ] - }, - { - "cell_type": "markdown", - "id": "93d8cd2f", - "metadata": {}, - "source": [ - "# Word2Vec模型:主要采用Skip-gram和Cbow两种模式\n", - " 前文提到的distributed representation稠密向量表达可以用Word2Vec模型进行训练得到。\n", - " skip-gram模型(跳字模型)是用中心词去预测周围词\n", - " cbow模型(连续词袋模型)是用周围词预测中心词" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "066f68a0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 10000 cost = 1.955088\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 20000 cost = 1.673096\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 30000 cost = 2.247422\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 40000 cost = 2.289902\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 50000 cost = 2.396217\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 60000 cost = 1.539688\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 70000 cost = 1.638879\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 80000 cost = 2.279797\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 90000 cost = 1.992100\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 100000 cost = 1.307715\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD6CAYAAACiefy7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnsUlEQVR4nO3de3hU1b3/8fc34U64KagRqQGLyC2BECSAXGysYKmAF4poRVFLU8VDbbViFZtqPW0PtB6pYkRBQFGOCoIoFn8IFDQIBAgIyr1RrhLBhATDJWT9/pghTUICCZnMTGY+r+fJk9lrr9nrO0PyYWftPXubcw4REQl9EYEuQERE/EOBLyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+FKtzCzGzDZVov9QM+tQnTWJhCsL5vPwmzdv7mJiYgJdhlTB8ePH2bFjBx07dqxQ/8zMTJo0aUKzZs2quTKR0LR27dpvnXMtylpXy9/FVEZMTAzp6emBLkOqIDMzkxtuuIGuXbuSlpZGy5YtmT9/Pq+//jpTpkzhxIkT/PCHP+S1114jIyODn/70pxQWFlJQUMCcOXMAeOCBB8jKyqJBgwa8/PLLXHXVVQF+VSLBy8y+Km+dpnSk2m3fvp0HHniAzZs307RpU+bMmcPNN9/MmjVr2LBhA+3bt2fq1Kn06tWLwYMHM2HCBDIyMrjiiisYPXo0//jHP1i7di0TJ07k/vvvD/TLEamxgnoPX0JD69at6dKlCwDdunUjMzOTTZs28cQTT5CdnU1eXh4DBgw443l5eXmkpaUxbNiworbjx4/7q2yRkKPAl2pXt27doseRkZHk5+dz9913M2/ePOLi4pg+fTrLli0743mFhYU0bdqUjIwM/xUrEsI0pSMBkZubS3R0NCdPnmTWrFlF7Y0aNSI3NxeAxo0b07p1a95++20AnHNs2LAhIPWKhAIFvgTE008/TY8ePfjxj39c4iDsbbfdxoQJE+jatSs7d+5k1qxZTJ06lbi4ODp27Mj8+fMDWLVIzRbUp2UmJCQ4naUTvuat38uERVvZl53PpU3r88iAdgzt2jLQZYkENTNb65xLKGud5vAlKM1bv5fH5n5O/slTAOzNzuexuZ8DKPRFzpOmdCQoTVi0tSjsT8s/eYoJi7YGqCKRmk+BL0FpX3Z+pdpF5NwU+BKULm1av1LtInJuCnwJSo8MaEf92pEl2urXjuSRAe0CVJFIzaeDthKUTh+Y1Vk6Ir6jwJegNbRrSwW8iA/5ZErHzKaZ2cHyrntuHpPMbIeZbTSzeF+MKyIiFeerOfzpwMCzrL8BaOv9Gg286KNxRUSkgnwS+M655cDhs3QZAsx0Hp8BTc0s2hdji4hIxfjrLJ2WwO5iy3u8bWcws9Fmlm5m6VlZWX4pTkQkHPgr8K2MtjIv4uOcm+KcS3DOJbRoUeZdukRE5Dz4K/D3AK2KLV8G7PPT2CIigv8C/z1gpPdsnUQgxzm3309ji4gIPjoP38zeBPoDzc1sD/AHoDaAcy4VWAj8BNgBfA+M8sW4IiJScT4JfOfciHOsd8ADvhhLRETOj66lIyISJhT4IiJhQoEvIhImFPgiImFCgS8iEiYU+BKSMjMz6dSp03k//+677+add97xYUUigafAFymloKAg0CWIVAsFvoSsgoIC7rrrLmJjY7n11lv5/vvveeqpp+jevTudOnVi9OjReD4iAv379+f3v/89/fr147nnniuxnfHjx3P33XdTWFgYiJch4jMKfAlZW7duZfTo0WzcuJHGjRszefJkxowZw5o1a9i0aRP5+fm8//77Rf2zs7P517/+xW9/+9uitt/97nccPHiQV199lYgI/bpIzaafYAlZrVq1onfv3gD8/Oc/55NPPmHp0qX06NGDzp07s2TJEjZv3lzUf/jw4SWe//TTT5Odnc1LL72EWVkXfBWpWXRPWwlZpUPazLj//vtJT0+nVatWpKSkcOzYsaL1DRs2LNG/e/furF27lsOHD3PBBRf4pWaR6qQ9fAlZX3/9NStXrgTgzTff5JprrgGgefPm5OXlnfMsnIEDBzJu3DgGDRpEbm5utdcrUt20hy8hq3379syYMYNf/vKXtG3bll/96ld89913dO7cmZiYGLp3737ObQwbNozc3FwGDx7MwoULqV+/vh8qF6kedvoshWCUkJDg0tPTA12GhJkPdn3Ac+ue48DRA1zS8BLGxo9lUJtBgS5LpELMbK1zLqGsddrDFynmg10fkJKWwrFTnrn9/Uf3k5KWAqDQlxpPc/gixTy37rmisD/t2KljPLfuuXKeIVJzKPBFijlw9ECl2kVqEgW+SDGXNLykUu0iNYkCX6SYsfFjqRdZr0Rbvch6jI0fG6CKRHxHgS9nlZqaysyZM32yrZiYGL799lufbKu6DGoziJReKUQ3jMYwohtGk9IrRQdsJSToLB05q+Tk5ECX4HeD2gxSwEtI8skevpkNNLOtZrbDzMaVsb6JmS0wsw1mttnMRvliXDk/Q4cOpVu3bnTs2JEpU6YAEBUVxeOPP05cXByJiYl88803AKSkpDBx4kTAc0XJhx56iL59+9K+fXvWrFnDzTffTNu2bXniiSfOun0RCbwqB76ZRQIvADcAHYARZtahVLcHgC+cc3FAf+BvZlanqmPL+Zk2bRpr164lPT2dSZMmcejQIY4ePUpiYiIbNmygb9++vPzyy2U+t06dOixfvpzk5GSGDBnCCy+8wKZNm5g+fTqHDh0qd/siEni+2MO/GtjhnNvlnDsBzAaGlOrjgEbmuZpVFHAY0F0mAmTSpElFe/K7d+9m+/bt1KlTh5/+9KcAdOvWjczMzDKfO3jwYAA6d+5Mx44diY6Opm7durRp04bdu3eXu/2aqlevXoEuQcRnfDGH3xLYXWx5D9CjVJ/ngfeAfUAjYLhzrsy7SZjZaGA0wA9+8AMflCfFLVu2jMWLF7Ny5UoaNGhA//79OXbsGLVr1y66umRkZGS5d32qW7cuABEREUWPTy8XFBSUu/2aKi0tLdAliPiML/bwy7pQeOkL9AwAMoBLgS7A82bWuKyNOeemOOcSnHMJLVq08EF5UlxOTg7NmjWjQYMGbNmyhc8++6xGbd/foqKiAl1CtSp+jEZCny8Cfw/QqtjyZXj25IsbBcx1HjuAfwNX+WBsqaSBAwdSUFBAbGws48ePJzExsUZtX0TOX5WvlmlmtYBtQBKwF1gD3O6c21ysz4vAN865FDO7GFgHxDnnznpStq6WWXMdXX+QI4syOZV9nMimdWk8IIaGXS8KdFmVFhUVRV5eXqDL8KlnnnmGmTNn0qpVK1q0aEG3bt247rrrSE5O5vvvv+eKK65g2rRpNGvWjDVr1nDvvffSsGFDrrnmGj788EM2bdoU6JcgZ3G2q2VWeQ/fOVcAjAEWAV8CbznnNptZspmdPon7aaCXmX0OfAw8eq6wl5rr6PqDZM/dzqns4wCcyj5O9tztHF1/MMCVydq1a5k9ezbr169n7ty5rFmzBoCRI0fy17/+lY0bN9K5c2f++Mc/AjBq1ChSU1NZuXIlkZGRgSxdfMAn5+E75xY65650zl3hnHvG25bqnEv1Pt7nnLveOdfZOdfJOfe6L8YV/5g+fTpjxoypcP8jizJxJ0sek3cnCzmyKNPHlUllrVixgptuuokGDRrQuHFjBg8ezNGjR8nOzqZfv34A3HXXXSxfvpzs7Gxyc3OLzlS6/fbbA1m6+IAurSA+d3rPvqLt4l8VvSF7MN8cSc6PAj/Elfep2t/+9rfEx8eTlJREVlYW4Pkk7a9//Wt69epFp06dWL169Rnby8rK4pZbbqF79+50796dTz/99Iw+kU3rntF2tvagsfEteLYTpDT1fN/4VsjN3/ft25d3332X/Px8cnNzWbBgAQ0bNqRZs2asWLECgNdee41+/frRrFkzGjVqVHSm1ezZswNZuviAAj/Elfep2vj4eNatW0e/fv2K5msBjh49SlpaGpMnT+aee+45Y3tjx47loYceYs2aNcyZM4f77rvvjD6NB8RgtUv+aFntCBoPiPH56/OZjW/Bgv+CnN2A83xf8F+e9hASHx/P8OHD6dKlC7fccgt9+vQBYMaMGTzyyCPExsaSkZHBk08+CcDUqVMZPXo0PXv2xDlHkyZNAlm+VJEunhbiJk2axLvvvgtQ9KnXiIgIhg8fDsDPf/5zbr755qL+I0aMADx7gkeOHCE7O7vE9hYvXswXX3xRtHzkyBFyc3Np1KhRUdvps3Fq1Fk6Hz8FJ/NLtp3M97TH/iwwNVWTxx9/nMcff/yM9tKfmdi4cSNLly7l5ptvpkmTJmRmZpKQUObJH1JDKPBDWEU/9Vp8Trf0/G7p5cLCQlauXEn9+vXPOnbDrhcFd8CXlrOncu0hbuPGjSxYsID169fzySefUFhYSLNmzXjppZcCXZpUgaZ0Qlh5n3otLCzknXfeAeCNN97gmmuuKXrO//3f/wHwySef0KRJkzP+hL/++ut5/vnni5YzMjKq+VX4SZPLKtce4j7++GNOnjxJp06dSE5O5v7772fEiBGsX78+0KVJFSjwQ1h5n3pt2LAhmzdvplu3bixZsqRovhagWbNm9OrVi+TkZKZOnXrGNidNmkR6ejqxsbF06NCB1NRUv72eapX0JNQu9VdL7fqe9jATFRVFTk4Oubm5vPWW5xhGRkYGCxcuJCcnJ8DVSVVU+ZO21UmftK0e5X16tH///kycOLHcedovVyxlxeyZ5B76lkYXNqfPbSNp3+fa6i7Xfza+5Zmzz9nj2bNPejLk5u8rIioqiqeffrpEuGdkZLBv3z5GjBjBQw89FMDq5Fyq9ZO2Eh6+XLGUj6Y8T+63WeAcud9m8dGU5/lyxdJAl+Y7sT+DhzZBSrbnexiG/WlJSUnk5eUxefLkoraIiAiSkpL44IMP6NmzJ99++y0fffQRPXv2JD4+nmHDhoXcaayhRoEfhsr7pVy2bFm5e/crZs+k4ETJD04VnDjOitm+ud+tBJfY2FiSkpKKLqdQv359YmJi2LlzJ3/5y19YuHAhAH/6059YvHgx69atIyEhgb///e+BLFvOQWfpSIXkHir70kfltUvN1759ey688EJSUlKYPn06EyZMIDMzk48++ojGjRvz/vvv88UXX9C7d28ATpw4Qc+ePQNctZyNAl8qpNGFzT3TOWW0S3ho06YNu3btYtu2bSQkJOCc48c//jFvvvlmoEuTCtKUjlRIn9tGUqtOyUsj1KpTlz63jQxQReJvl19+OXPnzmXkyJFs3ryZxMREPv30U3bs2AHA999/z7Zt2wJcpZyNAl8qpH2fa7l+9BgaNW8BZjRq3oLrR48JrbN05JzatWvHrFmzGDZsGEeOHGH69OmMGDGC2NhYEhMT2bJlS6BLlLPQaZkiUmn7D8xn186JHDu+n3p1o2lzxcNEXzIk0GUJZz8tU3P4IlIp+w/MZ8uWxyks9Fx76NjxfWzZ4rk2j0I/uGlKR0QqZdfOiUVhf1phYT67dupm6MFOgS8ilXLs+P5KtUvwUOCLSKXUqxtdqfbzkZmZyRtvvOGz7YmHAl9EKqXNFQ8TEVHyQnMREfVpc8XDPhtDgV89FPgiUinRlwzhqqueoV7dSwHju8MX8Iv7shn/xAI6derEHXfcweLFi+nduzdt27Zl9erVHD16lHvuuYfu3bvTtWtX5s+fD3iCvU+fPsTHxxMfH09aWhoA48aNY8WKFXTp0oVnn302gK82xDjnqvwFDAS2AjuAceX06Q9kAJuBf1Vku926dXMiEtz+/e9/u8jISLdx40Z36tQpFx8f70aNGuUKCwvdvHnz3JAhQ9xjjz3mXnvtNeecc999951r27aty8vLc0ePHnX5+fnOOee2bdvmTv/OL1261A0aNChgr6kmA9JdOZla5dMyzSwSeAH4MbAHWGNm7znnvijWpykwGRjonPvazGrQrZBE5Fxat25N586dAejYsSNJSUmYGZ07dyYzM5M9e/bw3nvvMXGi50yeY8eO8fXXX3PppZcyZswYMjIyiIyM1Cd1q5kvzsO/GtjhnNsFYGazgSHAF8X63A7Mdc59DeCcO+iDcUUkSNSt+5/LbkRERBQtR0REUFBQQGRkJHPmzKFdu3YlnpeSksLFF1/Mhg0bKCwspF69en6tO9z4Yg6/JbC72PIeb1txVwLNzGyZma01s3IvwGJmo80s3czSs7LOvFiXiNQ8AwYM4B//+Mfp6d2iWyXm5OQQHR1NREQEr732GqdOnQKgUaNG5ObmBqxeX5g+fTpjxowJdBkl+CLwrYy20tdrqAV0AwYBA4DxZnZlWRtzzk1xziU45xJatGjhg/JEJNDGjx/PyZMniY2NpVOnTowfPx6A+++/nxkzZpCYmMi2bdto2LAh4Lkef61atYiLiwvYQVvnHIWFhQEZu7pU+Vo6ZtYTSHHODfAuPwbgnPtzsT7jgHrOuRTv8lTgn865t8+27bKupTNp0iRefPFFDhw4wKOPPsq4cePKfO706dNJT08vccNtEQle89bvZcKirezLzufSpvV5ZEA7hnYtPVlQvTIzM7nhhhu49tprWblyJUOHDuX999/n+PHj3HTTTfzxj38EYOjQoezevZtjx44xduxYRo8eDcCrr77Kn//8Z6Kjo7nyyiupW7eu3zOouq+lswZoa2atgb3AbXjm7IubDzxvZrWAOkAP4Lz+2548eTIffvghrVu3rkLJIhJM5q3fy2NzPyf/pGdKZ292Po/N/RzA76G/detWXn31VYYOHco777zD6tWrcc4xePBgli9fTt++fZk2bRoXXHAB+fn5dO/enVtuuYUTJ07whz/8gbVr19KkSROuvfZaunbt6tfaz6XKUzrOuQJgDLAI+BJ4yzm32cySzSzZ2+dL4J/ARmA18IpzblNlx0pOTmbXrl0MHjyYZ599tmh+7O2336ZTp07ExcXRt2/fov779u1j4MCBtG3blt/97ndVfakiUk0mLNpaFPan5Z88xYRFW/1ey+WXX05iYiIfffQRH330EV27diU+Pp4tW7awfft2wDPTEBcXR2JiIrt372b79u2sWrWK/v3706JFC+rUqcPw4cP9Xvu5+ORqmc65hcDCUm2ppZYnABOqMk5qair//Oc/Wbp0Ke+//35R+1NPPcWiRYto2bIl2dnZRe0ZGRmsX7+eunXr0q5dOx588EFatWpVlRJEpBrsy86vVHt1On0cwTnHY489xi9/+csS65ctW8bixYtZuXIlDRo0oH///hw7dgwAs7IOaQaPkPikbe/evbn77rt5+eWXi47yAyQlJdGkSRPq1atHhw4d+OqrrwJYpYiU59Km9SvV7g8DBgxg2rRp5OXlAbB3714OHjxITk4OzZo1o0GDBmzZsoXPPvsMgB49erBs2TIOHTrEyZMnefvtsx6iDIiQuB5+amoqq1at4oMPPqBLly5kZGQAJc8NjoyMpKCgIEAVisjZPDKgXYk5fID6tSN5ZEC7szyrel1//fV8+eWXRTdmj4qK4vXXX2fgwIGkpqYSGxtLu3btSExMBCA6OpqUlBR69uxJdHQ08fHxJXZAg0FIBP7OnTvp0aMHPXr0YMGCBezevfvcTxKRoHH6wGygz9KJiYlh06b/HF4cO3YsY8eOPaPfhx9+eEZbzoIFXPPa68yPrEWtyFpcdN11NLnxxmqtt7JCIvAfeeQRtm/fjnOOpKQk4uLiivbyRaRmGNq1pd8D3ldyFixg//gncd65/IJ9+9g//kmAoAp93dNWRKSKtv8oiYJ9+85or3XppbRd8rFfawm7e9rOOXCYP+/az97jJ2lZtzaPtYnmlksuCHRZIhKiCvaXfbev8toDJSTO0iluzoHDPLx1N3uOn8QBe46f5OGtu5lz4HCgSxOREFUruuy7fZXXHighF/h/3rWf/MKS01T5hY4/7wqu/2lFJHRc9NCvsVJX+rR69bjooV8HpqByhNyUzt7jJyvVLiJSVacPzB589n8p2L+fWtHRXPTQr4PqgC2EYOC3rFubPWWEe8u6tQNQjYiEiyY33hh0AV9ayE3pPNYmmvoRJT/eXD/CeKxNcM2liYj4W8jt4Z8+G0dn6YiIlBRygQ+e0FfAi4iUFHJTOiIiUjYFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGzcWfp1N7NTZnarL8YVEZGKq3Lgm1kk8AJwA9ABGGFmHcrp91dgUVXHFBGRyvPFHv7VwA7n3C7n3AlgNjCkjH4PAnOAgz4YU0REKskXgd8S2F1seY+3rYiZtQRuAlLPtTEzG21m6WaWnpWV5YPyREQEfBP4VkZb6Tuj/y/wqHPu1Lk25pyb4pxLcM4ltGjRwgfliYgI+OZqmXuAVsWWLwNK3749AZhtZgDNgZ+YWYFzbp4PxpcaIDU1lQYNGjBy5MhAlyIStnwR+GuAtmbWGtgL3AbcXryDc6716cdmNh14X2EfXpKTkwNdgkjYq/KUjnOuABiD5+ybL4G3nHObzSzZzPRbXgNlZmZy1VVXcd9999GpUyfuuOMOFi9eTO/evWnbti2rV6/m8OHDDB06lNjYWBITE9m4cSOFhYXExMSQnZ1dtK0f/vCHfPPNN6SkpDBx4kQAdu7cycCBA+nWrRt9+vRhy5YtAXqlIuHFJzdAcc4tBBaWaivzAK1z7m5fjCnVa8eOHbz99ttMmTKF7t2788Ybb/DJJ5/w3nvv8d///d+0atWKrl27Mm/ePJYsWcLIkSPJyMhgyJAhvPvuu4waNYpVq1YRExPDxRdfXGLbo0ePJjU1lbZt27Jq1Sruv/9+lixZEqBXKhI+QvKOV1J1rVu3pnPnzgB07NiRpKQkzIzOnTuTmZnJV199xZw5cwD40Y9+xKFDh8jJyWH48OE89dRTjBo1itmzZzN8+PAS283LyyMtLY1hw4YVtR0/ftx/L0wkjCnwpUx169YtehwREVG0HBERQUFBAbVqnfmjY2b07NmTHTt2kJWVxbx583jiiSdK9CksLKRp06ZkZGRUa/0iciZdS0fOS9++fZk1axYAy5Yto3nz5jRu3Bgz46abbuI3v/kN7du358ILLyzxvMaNG9O6dWvefvttAJxzbNiwwe/1i4QjBb6cl5SUFNLT04mNjWXcuHHMmDGjaN3w4cN5/fXXz5jOOW3WrFlMnTqVuLg4OnbsyPz58/1VtkhYM+dKf0YqeCQkJLj09PRAlyE+9MGuD3hu3XMcOHqASxpewtj4sQxqMyjQZYmEDDNb65xLKGud5vDFbz7Y9QEpaSkcO3UMgP1H95OSlgKg0BfxA03piN88t+65orA/7dipYzy37rkAVSQSXhT44jcHjh6oVLuI+JYCX/zmkoaXVKpdRHxLgS9+MzZ+LPUi65VoqxdZj7HxYwNUkUh40UFb8ZvTB2Z1lo5IYCjwxa8GtRmkgBcJEE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJnwS+mQ00s61mtsPMxpWx/g4z2+j9SjOzOF+MKyIiFVflwDezSOAF4AagAzDCzDqU6vZvoJ9zLhZ4GphS1XFFRKRyfLGHfzWwwzm3yzl3ApgNDCnewTmX5pz7zrv4GXCZD8YVEQlZy5YtIy0tzafb9EXgtwR2F1ve420rz73Ah+WtNLPRZpZuZulZWVk+KE9EpOYJ1sC3MtrKvDO6mV2LJ/AfLW9jzrkpzrkE51xCixYtfFCeiEjwmDlzJrGxscTFxXHnnXeyYMECevToQdeuXbnuuuv45ptvyMzMJDU1lWeffZYuXbqwYsUKn4zti8sj7wFaFVu+DNhXupOZxQKvADc45w75YFwRkRpl8+bNPPPMM3z66ac0b96cw4cPY2Z89tlnmBmvvPIK//M//8Pf/vY3kpOTiYqK4uGHH/bZ+L4I/DVAWzNrDewFbgNuL97BzH4AzAXudM5t88GYIiI1zpIlS7j11ltp3rw5ABdccAGff/45w4cPZ//+/Zw4cYLWrVtX2/hVntJxzhUAY4BFwJfAW865zWaWbGbJ3m5PAhcCk80sw8zSqzquiEhN45zDrOQs+IMPPsiYMWP4/PPPeemllzh27Fi1je+T8/Cdcwudc1c6565wzj3jbUt1zqV6H9/nnGvmnOvi/UrwxbgiIjVJUlISb731FocOeWa1Dx8+TE5ODi1bes5zmTFjRlHfRo0akZub69Px9UlbERE/6dixI48//jj9+vUjLi6O3/zmN6SkpDBs2DD69OlTNNUDcOONN/Luu+/69KCtOVfmCTVBISEhwaWna/ZHRMLDtlUHWDl/J3mHjxN1QV16DrmCK3tcUqltmNna8mZRdBNzEZEgsG3VAZbO2kLBiUIA8g4fZ+msLQCVDv3yaEpHRCQIrJy/syjsTys4UcjK+Tt9NoYCX0QkCOQdPl6p9vOhwBcRCQJRF9StVPv5UOCLiASBnkOuoFadkpFcq04EPYdc4bMxdNBWRCQInD4wW9WzdM5GgS8iEiSu7HGJTwO+NE3piIiECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhAkFvohImPBJ4JvZQDPbamY7zGxcGevNzCZ51280s3hfjCsiIhVX5cA3s0jgBeAGoAMwwsw6lOp2A9DW+zUaeLGq44qISOX4Yg//amCHc26Xc+4EMBsYUqrPEGCm8/gMaGpm0T4YW0REKsgXgd8S2F1seY+3rbJ9ADCz0WaWbmbpWVlZPihPRETAN4FvZbS58+jjaXRuinMuwTmX0KJFiyoXJyIiHr4I/D1Aq2LLlwH7zqOPiIhUI18E/hqgrZm1NrM6wG3Ae6X6vAeM9J6tkwjkOOf2+2BsERGpoCrf09Y5V2BmY4BFQCQwzTm32cySvetTgYXAT4AdwPfAqKqOKyIileOTm5g75xbiCfXibanFHjvgAV+MJSIi50eftBURCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMKHAFxEJEwp8EZEwocAXEQkTCnwRkTChwBcRCRMKfBGRMFGlwDezC8zs/5nZdu/3ZmX0aWVmS83sSzPbbGZjqzKmiIicn6ru4Y8DPnbOtQU+9i6XVgD81jnXHkgEHjCzDlUcV0REKqmqgT8EmOF9PAMYWrqDc26/c26d93Eu8CXQsorjiohIJVU18C92zu0HT7ADF52ts5nFAF2BVWfpM9rM0s0sPSsrq9IF9erVq9LPEREJB7XO1cHMFgOXlLHq8coMZGZRwBzg1865I+X1c85NAaYAJCQkuMqMAZCWllbZp4iIhIVzBr5z7rry1pnZN2YW7Zzbb2bRwMFy+tXGE/aznHNzz7vaCoiKiiIvL4/9+/czfPhwjhw5QkFBAS+++CJ9+vSpzqFFRIJaVad03gPu8j6+C5hfuoOZGTAV+NI59/cqjldhb7zxBgMGDCAjI4MNGzbQpUsXfw0tIhKUzrmHfw5/Ad4ys3uBr4FhAGZ2KfCKc+4nQG/gTuBzM8vwPu/3zrmFVRz7rLp3784999zDyZMnGTp0qAJfRMJelfbwnXOHnHNJzrm23u+Hve37vGGPc+4T55w552Kdc128X9Ua9gB9+/Zl+fLltGzZkjvvvJOZM2dW95AiIkEtZD9p+9VXX3HRRRfxi1/8gnvvvZd169YFuiQRkYCq6pRO0Fq2bBkTJkygdu3aREVFaQ9fRMJeyAV+Xto0eLYTd+Xs4a57L4OkxyD2Z4EuS0Qk4EIr8De+BQv+C07me5ZzdnuWQaEvImEvtObwP37qP2F/2sl8T7uISJgLrcDP2VO5dhGRMBJagd/kssq1i4iEkdAK/KQnoXb9km2163vaRUTCXGgFfuzP4MZJ0KQVYJ7vN07SAVsREULtLB3whLsCXkTkDKG1hy8iIuVS4IuIhAkFvohImFDgi4iECQW+iEiYMOcqfdtYvzGzLOArH22uOfCtj7blS6qrclRX5QRrXRC8tdX0ui53zrUoa0VQB74vmVm6cy4h0HWUproqR3VVTrDWBcFbWyjXpSkdEZEwocAXEQkT4RT4UwJdQDlUV+WorsoJ1rogeGsL2brCZg5fRCTchdMevohIWFPgi4iEiZAKfDMbaGZbzWyHmY0rY/0dZrbR+5VmZnFBVNsQb10ZZpZuZtcEQ13F+nU3s1Nmdmsw1GVm/c0sx/t+ZZiZX256UJH3y1tbhpltNrN/BUNdZvZIsfdqk/ff8oIgqKuJmS0wsw3e92tUdddUwbqamdm73t/J1WbWyU91TTOzg2a2qZz1ZmaTvHVvNLP4Sg3gnAuJLyAS2Am0AeoAG4AOpfr0App5H98ArAqi2qL4zzGVWGBLMNRVrN8SYCFwazDUBfQH3g/Cn7GmwBfAD7zLFwVDXaX63wgsCYa6gN8Df/U+bgEcBuoEQV0TgD94H18FfOynn7G+QDywqZz1PwE+BAxIrGyGhdIe/tXADufcLufcCWA2MKR4B+dcmnPuO+/iZ4C/7n1YkdrynPdfFGgI+ONo+jnr8noQmAMc9ENNlanL3ypS1+3AXOfc1wDOOX+8Z5V9v0YAbwZJXQ5oZGaGZ6fnMFAQBHV1AD4GcM5tAWLM7OJqrgvn3HI870F5hgAzncdnQFMzi67o9kMp8FsCu4st7/G2ledePP9T+kOFajOzm8xsC/ABcE8w1GVmLYGbgFQ/1FPhurx6eqcCPjSzjkFS15VAMzNbZmZrzWxkkNQFgJk1AAbi+Q88GOp6HmgP7AM+B8Y65wqDoK4NwM0AZnY1cDn+20E8m8rmXAmhFPhWRluZe8lmdi2ewH+0WisqNmQZbWfU5px71zl3FTAUeLq6i6Jidf0v8Khz7lT1l1OkInWtw3PNkDjgH8C86i6KitVVC+gGDAIGAOPN7MogqOu0G4FPnXNn24v0lYrUNQDIAC4FugDPm1nj6i2rQnX9Bc9/3Bl4/sJdT/X/5VERlfm3PkMo3eJwD9Cq2PJlePYaSjCzWOAV4Abn3KFgqu0059xyM7vCzJo756rzIk4VqSsBmO35i5vmwE/MrMA5Ny+QdTnnjhR7vNDMJgfJ+7UH+NY5dxQ4ambLgThgW4DrOu02/DOdAxWraxTwF+905g4z+zeeOfPVgazL+/M1CjwHSoF/e78CrVJZcgZ/HIjw08GOWsAuoDX/ORDTsVSfHwA7gF5BWNsP+c9B23hg7+nlQNZVqv90/HPQtiLv1yXF3q+rga+D4f3CMz3xsbdvA2AT0CnQdXn7NcEzP9ywuv8NK/F+vQikeB9f7P25bx4EdTXFe/AY+AWeefNqf8+848VQ/kHbQZQ8aLu6MtsOmT1851yBmY0BFuE5Cj/NObfZzJK961OBJ4ELgcnePdYC54er4lWwtluAkWZ2EsgHhjvvv3CA6/K7CtZ1K/ArMyvA837dFgzvl3PuSzP7J7ARKARecc6VeYqdP+vydr0J+Mh5/vqodhWs62lgupl9jifEHnXV+1daRetqD8w0s1N4zrq6tzprOs3M3sRzBlpzM9sD/AGoXayuhXjO1NkBfI/3r5AKb7+af0dERCRIhNJBWxEROQsFvohImFDgi4iECQW+iEiYUOCLiIQJBb6ISJhQ4IuIhIn/DyCri3Zc6/JlAAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "打印\n" - ] - }, - { - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.autograd import variable\n", - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "from tqdm import tqdm\n", - "\n", - "dtype = torch.FloatTensor\n", - "#我们使用的语料库 \n", - "sentences = ['i like dog','i like cat','i like animal','dog is animal','cat is animal',\n", - " 'dog like meat','cat like meat','cat like fish','dog like meat','i like apple',\n", - " 'i hate apple','i like movie','i like read','dog like bark','dog like cat']\n", - "\n", - "\n", - "\n", - "word_sequence = ' '.join(sentences).split() #将语料库的每一句话的每一个词转化为列表 \n", - "#print(word_sequence)\n", - "\n", - "word_list = list(set(word_sequence)) #构建我们的词表 \n", - "#print(word_list)\n", - "\n", - "#word_voc = list(set(word_sequence)) \n", - "\n", - "#接下来对此表中的每一个词编号 这就用到了我们之前提到的one-hot编码 \n", - "\n", - "#词典 词对应着编号\n", - "word_dict = {w:i for i,w in enumerate(word_list)}\n", - "#print(word_dict)\n", - "#编号对应着词\n", - "index_dict = {i:w for w,i in enumerate(word_list)}\n", - "#print(index_dict)\n", - "\n", - "\n", - "batch_size = 2\n", - "voc_size = len(word_list)\n", - "\n", - "skip_grams = []\n", - "for i in range(1,len(word_sequence)-1,3):\n", - " target = word_dict[word_sequence[i]] #当前词对应的id\n", - " context = [word_dict[word_sequence[i-1]],word_dict[word_sequence[i+1]]] #两个上下文词对应的id\n", - "\n", - " for w in context:\n", - " skip_grams.append([target,w])\n", - "\n", - "embedding_size = 10 \n", - "\n", - "\n", - "class Word2Vec(nn.Module):\n", - " def __init__(self):\n", - " super(Word2Vec,self).__init__()\n", - " self.W1 = nn.Parameter(torch.rand(len(word_dict),embedding_size)).type(dtype) \n", - " #将词的one-hot编码对应到词向量中\n", - " self.W2 = nn.Parameter(torch.rand(embedding_size,voc_size)).type(dtype)\n", - " #将词向量 转化为 输出 \n", - " def forward(self,x):\n", - " hidden_layer = torch.matmul(x,self.W1)\n", - " output_layer = torch.matmul(hidden_layer,self.W2)\n", - " return output_layer\n", - "\n", - "\n", - "model = Word2Vec()\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(),lr=1e-5)\n", - "\n", - "#print(len(skip_grams))\n", - "#训练函数\n", - "\n", - "def random_batch(data,size):\n", - " random_inputs = []\n", - " random_labels = []\n", - " random_index = np.random.choice(range(len(data)),size,replace=False)\n", - " \n", - " for i in random_index:\n", - " random_inputs.append(np.eye(voc_size)[data[i][0]]) #从一个单位矩阵生成one-hot表示\n", - " random_labels.append(data[i][1])\n", - " \n", - " return random_inputs,random_labels\n", - "\n", - "for epoch in tqdm(range(100000)):\n", - " input_batch,target_batch = random_batch(skip_grams,batch_size) # X -> y\n", - " input_batch = torch.Tensor(input_batch)\n", - " target_batch = torch.LongTensor(target_batch)\n", - "\n", - " optimizer.zero_grad()\n", - "\n", - " output = model(input_batch)\n", - "\n", - " loss = criterion(output,target_batch)\n", - " if((epoch+1)%10000==0):\n", - " print(\"epoch:\",\"%04d\" %(epoch+1),'cost =' ,'{:.6f}'.format(loss))\n", - "\n", - " loss.backward() \n", - " optimizer.step()\n", - "\n", - "for i , label in enumerate(word_list):\n", - " W1,_ = model.parameters()\n", - " x,y = float(W1[i][0]),float(W1[i][1])\n", - " plt.scatter(x,y)\n", - " plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1edccf25", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytorch", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.9 (default, Aug 31 2020, 12:42:55) \n[GCC 7.3.0]" - }, - "vscode": { - "interpreter": { - "hash": "7648c2b9d25760d0d65f53f9b9a34de48caa24d8265d64b0ff81e2f2641d528d" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}