torchrec cn_embedding模块设计方案
大约 2 分钟
torchrec cn_embedding模块设计方案
1.目标
- 实现对torchrec embedding/embeddingbag底层存储结构的替换,用于实现embedding table动态扩容,贴合实际业务场景
- 训练DLRM网络性能目标:吞吐550万,精度0.8025
2.替换内容
- 使用collection替换torch.nn.EmbeddingBag底层数据结构
- 需要实现:
- embeddingbag随机初始化
- embedding的池化操作:sum(主要)/mean/max
- forward():向embeddingbag中传入一个/一批待查找key(Tensor),embeddingbag返回查询的embedding vector(一个/多个embedding做池化后得到的Tensor)
- 实现cpu和cncard端embedding table(attr: device, embeddingbag.to() : "meta"/"cpu"/"cncard")
- 支持动态扩容,可新增相关类方法
3.替换方案
编译collection为动态库,使用c++ 开发cn_embedding数据结构及对应方法,对齐原生embedingbag接口,新增动态扩容方法
使用pytorch官方接口自定义cn_embedding算子供python端调用
将torchrec/torchrec/modules/embedding_modules.py中的nn.Embedding(Bag)替换为CNEmbedding(Bag)
需要实现embedding的init(创建emb时,需要insert表条目,设定初始值)&lookup(前向查表)&update(反向更新表)
具体实现思路:
- c++: 使用collection封装为cn_embedding模块,提供CNEmbedding_C类接口,提供以下类方法
- init:申请空间,从0至emb table size逐个插入k-v对,emb vector使用随机值
- lookup:查表,用于前向
- update:更新表,用于反向
- python:实现CNEmbedding(bag)的自定义层,继承nn.Module
- 主要成员:
- emb_key:记录待查询key_id,用于后续更新embedding table
- emb_vec:记录查询得到的emb vector,作为可训练参数,用于后续更新embedding teable
- cn_emb_tbl:使用CNEmbedding_C接口创建的embedding table
- forward:
- 从cn_emb_tbl中查表得到emb vectors,(若是embeddingbag,还需要对它们做pooling:sum/mean/max) 赋值并返回emb_vec
- backward:
- 将返回的梯度更新到emb_vec,再cn_emb_tbl.update(emb_key, emb_vec)更新回embedding table
- 主要成员:
- c++: 使用collection封装为cn_embedding模块,提供CNEmbedding_C类接口,提供以下类方法
实现方案:
- 总体:使用pytorch扩展库的方式实现整个
cn_embedding
模块,c++端使用torch script、collection实现基本的嵌入表创建查找更新等操作,再封装为pytorch custom class向python端暴露接口;python端实现CNEmbedding和CNEmbeddingBag层,用于替换pytorch原生torch.nn.Embedding 和 torch.nn.EmbeddingBag。最后,再将Torchrec的数据结构EmbeddingBagCollection的嵌入层实现替换为CNEmbedding和CNEmbeddingBag
- 总体:使用pytorch扩展库的方式实现整个
4.排期
- [x] 6.5 调研nn.EmbeddingBag各参数的作用,底层实现方法,如何做insert,对齐python接口,调研开发pytorch c++扩展库方法
- [x] 6.6-6.7 实现CNEmbedding创建与初始化
- [x] 6.8-6.9 实现lookup&update
- [x] 6.12-6.14 python端api开发
- [x] 6.15-6.16 unittest & dlrm网络测试(需适配框架后才能测试)