PyTorch入门(八)Embedding层
词向量实现在PyTorch中对应于Embedding层,其实现代码的源码函数(PyTorch的版本为2.0)如下:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)
该函数随机会生成了一个向量,可以把它看作一个词向量查询表,其size为[num_embeddings,embedding_dim] 。其中num_embeddings是查询表的大小,embedding_dim是每个查询向量的维度。
函数参数解释:
- num_embeddings: int, 查询表的大小
- embedding_dim: int, 每个查询向量的维度
- padding_idx: int, 填充id
- max_norm: float, 最大范数,每个范数超过max_norm的embedding向量会重新规范化至其范数为max_norm
- norm_type: float, p范数(p=norm_type),默认值为2
需要注意的是,查询的下标向量的数据类型必须使Long,即Int64.
让我们来看几个Embedding层的使用例子。
- 例子1:
1 |
|
输出:
1 |
|
- 例子2:
1 |
|
输出:
1 |
|
从中我们可以发现,Embedding层中padding_idx对应的向量为零向量。
参考文献
- Embedding source code: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。
PyTorch入门(八)Embedding层
https://percent4.github.io/PyTorch入门(八)Embedding层/