This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

#Text classification with Transformer# code submission #937

Open
wants to merge 19 commits into
base: develop

YinHang2515

@CLAassistant

CLAassistant commented Nov 30, 2020

CLA assistant check
All committers have signed the CLA.

"source": [
"import paddle\n",
"import paddle.nn as nn\n",
"import paddle.fluid.dygraph as dg\n",

Paddle 2.0 discourages using `fluid`; dynamic graph mode is now the default development mode.

"pad_id = word_dict['<pad>']\r\n",
"embed_dim = 32 # Embedding size for each token\r\n",
"num_heads = 2 # Number of attention heads\r\n",
"ff_dim = 32 # Hidden layer size in feed forward network inside transformer\r\n",

The variable name `ff_dim` is not very clear.

" x = self.drop2(x)\r\n",
" x = self.soft(x)\r\n",
" return x\r\n",
"# class MyNet(paddle.nn.Layer):\r\n",

The commented-out code here can be deleted.

},
"source": [
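"As you can see, after two epochs of training the model reaches about 85% accuracy. You can of course improve performance further by tuning the parameters, changing the optimization method, and so on."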
]

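You could use `model.predict` to run predictions and print each sentence together with its predicted label and actual label, which would be more intuitive.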
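A minimal sketch of what that output could look like. It assumes the per-class scores have already been obtained (for example from `model.predict`) and only shows the formatting step. The helper `format_predictions`, the toy vocabulary, and the label names are hypothetical and not part of the notebook; the real label mapping depends on the dataset.

```python
def format_predictions(token_id_batches, class_scores, true_labels,
                       id_to_word, pad_id,
                       label_names=("class 0", "class 1")):
    """Return (sentence, predicted label, actual label) for each sample."""
    rows = []
    for token_ids, scores, label in zip(token_id_batches, class_scores, true_labels):
        # Drop padding tokens and map ids back to words.
        words = [id_to_word[i] for i in token_ids if i != pad_id]
        # The predicted class is the index of the highest score.
        predicted = max(range(len(scores)), key=lambda k: scores[k])
        rows.append((" ".join(words), label_names[predicted], label_names[label]))
    return rows


# Toy data standing in for the real vocabulary and model outputs (assumed values).
id_to_word = {0: "<pad>", 1: "great", 2: "movie", 3: "boring", 4: "plot"}
rows = format_predictions(
    token_id_batches=[[1, 2, 0, 0], [3, 4, 0, 0]],
    class_scores=[[0.1, 0.9], [0.4, 0.6]],
    true_labels=[1, 0],
    id_to_word=id_to_word,
    pad_id=0,
)
for sentence, predicted, actual in rows:
    print(f"sentence: {sentence} | predicted: {predicted} | actual: {actual}")
```

Printing a few rows like this makes misclassified samples easy to spot next to the overall accuracy figure.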

@YinHang2515
Author

I have made the requested changes and synced the update to AI Studio.


@chenxiaozeng chenxiaozeng left a comment

2 suggestions.

"class PointWiseFeedForwardNetwork(nn.Layer):\r\n",
" def __init__(self, embed_dim, feed_dim):\r\n",
" super(PointWiseFeedForwardNetwork, self).__init__()\r\n",
" self.linear1 = pd.fluid.dygraph.Linear(embed_dim, feed_dim, act='relu')\r\n",

`fluid` is used in several places and needs to be changed to `nn`.

" loss=nn.CrossEntropyLoss())\r\n",
"\r\n",
"# Train the model\r\n",
"model.fit(train_loader,\r\n",

After training finishes, you can call `model.predict()` to check how the model performs on the test dataset.

@@ -0,0 +1 @@

Should this file be deleted?

@YinHang2515
Author

I have made the requested changes and synced the update to AI Studio.


@chenxiaozeng chenxiaozeng left a comment

LGTM

},
"outputs": [],
"source": [
"class TransformerBlock(nn.Layer):\r\n",
Contributor

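Paddle already provides Transformer APIs: https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/nn/layer/transformer/TransformerEncoder_cn.html#transformerencoder . If the goal is only to use a Transformer rather than to explain the implementation details, could you use these APIs directly?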

Contributor

@TCChenlong TCChenlong left a comment

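Besides the issues above, there are two more points to note:
1. Paddle 2.0 has been released, so please update to version 2.0.
2. The prediction results do not look very good; please consider optimizing the network further.
Thanks!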

},
"outputs": [],
"source": [
"import paddle as pd\n",
Contributor

import paddle

"source": [
"import paddle as pd\n",
"import paddle.nn as nn\n",
"import paddle.nn.functional as func\n",
Contributor

Writing it this way is not recommended for now.

"train_dataset = IMDBDataset(train_sents, train_labels)\r\n",
"test_dataset = IMDBDataset(test_sents, test_labels)\r\n",
"\r\n",
"train_loader = pd.io.DataLoader(train_dataset, places=pd.CPUPlace(), return_list=True,\r\n",
Contributor

`places=pd.CPUPlace()` can be removed.

5 participants