1.目标
- 实现对torchrec embedding/embeddingbag底层存储结构的替换,用于实现embedding table动态扩容,贴合实际业务场景
- 训练DLRM网络性能目标:吞吐550万,精度0.8025
2.替换内容
- 使用collection替换torch.nn.EmbeddingBag底层数据结构
- 需要实现:
- embeddingbag随机初始化
- embedding的池化操作:sum(主要)/mean/max
- forward():向embeddingbag中传入一个/一批待查找key(Tensor),embeddingbag返回查询的embedding vector(一个/多个embedding做池化后得到的Tensor)
- 实现cpu和cncard端embedding table(attr: device, embeddingbag.to() : "meta"/"cpu"/"cncard")
- 支持动态扩容,可新增相关类方法
大约 2 分钟