跳至主要內容
torchrec cn_embedding模块设计方案

1.目标

  • 实现对torchrec embedding/embeddingbag底层存储结构的替换,用于实现embedding table动态扩容,贴合实际业务场景
  • 训练DLRM网络性能目标:吞吐550万,精度0.8025

2.替换内容

  • 使用collection替换torch.nn.EmbeddingBag底层数据结构
  • 需要实现:
    • embeddingbag随机初始化
    • embedding的池化操作:sum(主要)/mean/max
    • forward():向embeddingbag中传入一个/一批待查找key(Tensor),embeddingbag返回查询的embedding vector(一个/多个embedding做池化后得到的Tensor)
    • 实现cpu和cncard端embedding table(attr: device, embeddingbag.to() : "meta"/"cpu"/"cncard")
    • 支持动态扩容,可新增相关类方法

BradZhone大约 2 分钟推荐系统CNCardEmbedding