#Text Classification with Transformer# Code submission #937
base: develop
Conversation
"source": [
"import paddle\n",
"import paddle.nn as nn\n",
"import paddle.fluid.dygraph as dg\n",
Paddle 2.0 discourages using fluid; dynamic graph is the default development mode.
"pad_id = word_dict['<pad>']\r\n",
"embed_dim = 32 # Embedding size for each token\r\n",
"num_heads = 2 # Number of attention heads\r\n",
"ff_dim = 32 # Hidden layer size in feed forward network inside transformer\r\n",
The variable name ff_dim is not very clear.
"        x = self.drop2(x)\r\n",
"        x = self.soft(x)\r\n",
"        return x\r\n",
"# class MyNet(paddle.nn.Layer):\r\n",
The commented-out code here can be deleted.
},
"source": [
"After two epochs of training, the model reaches about 85% accuracy; you can improve it further by tuning hyperparameters, changing the optimizer, and so on."
]
Consider using model.predict for prediction and printing each sentence with its predicted label and actual label; that is easier to read.
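A minimal sketch of the display step, in case it helps. The helper name `format_predictions`, the label names and their order, and the sample inputs are all hypothetical; in the notebook the probabilities would come from model.predict on the test data, and the label order should be checked against the dataset's own encoding.

```python
def format_predictions(sentences, probs, labels, label_names=("negative", "positive")):
    """Return one line per sample: sentence, predicted label, actual label.

    `label_names` is an assumed mapping from class index to name.
    """
    lines = []
    for sentence, prob, label in zip(sentences, probs, labels):
        # Index of the highest class probability (argmax).
        pred = max(range(len(prob)), key=lambda i: prob[i])
        lines.append(
            f"{sentence!r} | predicted: {label_names[pred]} | actual: {label_names[label]}"
        )
    return lines


# Made-up placeholder inputs, only to show the output format.
demo_sentences = ["a wonderful film", "dull and far too long"]
demo_probs = [[0.1, 0.9], [0.6, 0.4]]
demo_labels = [1, 0]

for line in format_predictions(demo_sentences, demo_probs, demo_labels):
    print(line)
```

Printing a handful of samples this way makes misclassifications visible at a glance, which a single accuracy number does not.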
I have made the requested changes and synced the update to AIStudio.
2 suggestions.
"class PointWiseFeedForwardNetwork(nn.Layer):\r\n",
"    def __init__(self, embed_dim, feed_dim):\r\n",
"        super(PointWiseFeedForwardNetwork, self).__init__()\r\n",
"        self.linear1 = pd.fluid.dygraph.Linear(embed_dim, feed_dim, act='relu')\r\n",
Several uses of fluid need to be changed to nn.
" loss=nn.CrossEntropyLoss())\r\n", | ||
"\r\n", | ||
"# 模型训练\r\n", | ||
"model.fit(train_loader,\r\n", |
After training finishes, you can call model.predict() to check how the model performs on the test dataset.
@@ -0,0 +1 @@
Should this file be deleted?
I have made the requested changes and synced the update to AIStudio.
LGTM
},
"outputs": [],
"source": [
"class TransformerBlock(nn.Layer):\r\n",
Paddle already provides Transformer APIs: https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/nn/layer/transformer/TransformerEncoder_cn.html#transformerencoder . If the goal is only to use a Transformer rather than to explain how it is implemented, could you use these APIs directly?
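Besides the issues above, two more points need attention:
1. Version 2.0 has been released, so please update to 2.0.
2. The prediction results do not look very good; please optimize the network further.
Thanks!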
},
"outputs": [],
"source": [
"import paddle as pd\n",
import paddle
"source": [
"import paddle as pd\n",
"import paddle.nn as nn\n",
"import paddle.nn.functional as func\n",
Writing it this way is not recommended for now.
"train_dataset = IMDBDataset(train_sents, train_labels)\r\n",
"test_dataset = IMDBDataset(test_sents, test_labels)\r\n",
"\r\n",
"train_loader = pd.io.DataLoader(train_dataset, places=pd.CPUPlace(), return_list=True,\r\n",
places=pd.CPUPlace() can be removed.
Project URL: https://aistudio.baidu.com/aistudio/projectdetail/1247954