跳至主要內容

torchrec cn_embedding模块设计方案

BradZhone大约 2 分钟推荐系统CNCardEmbedding

torchrec cn_embedding模块设计方案

1.目标

  • 实现对torchrec embedding/embeddingbag底层存储结构的替换,用于实现embedding table动态扩容,贴合实际业务场景
  • 训练DLRM网络性能目标:吞吐550万,精度0.8025

2.替换内容

  • 使用collection替换torch.nn.EmbeddingBag底层数据结构
  • 需要实现:
    • embeddingbag随机初始化
    • embedding的池化操作:sum(主要)/mean/max
    • forward():向embeddingbag中传入一个/一批待查找key(Tensor),embeddingbag返回查询的embedding vector(一个/多个embedding做池化后得到的Tensor)
    • 实现cpu和cncard端embedding table(attr: device, embeddingbag.to() : "meta"/"cpu"/"cncard")
    • 支持动态扩容,可新增相关类方法

3.替换方案

  • 编译collection为动态库,使用c++ 开发cn_embedding数据结构及对应方法,对齐原生embedingbag接口,新增动态扩容方法

  • 使用pytorch官方接口自定义cn_embedding算子供python端调用

  • 将torchrec/torchrec/modules/embedding_modules.py中的nn.Embedding(Bag)替换为CNEmbedding(Bag)

  • 需要实现embedding的init(创建emb时,需要insert表条目,设定初始值)&lookup(前向查表)&update(反向更新表)

  • 具体实现思路:

    • c++: 使用collection封装为cn_embedding模块,提供CNEmbedding_C类接口,提供以下类方法
      • init:申请空间,从0至emb table size逐个插入k-v对,emb vector使用随机值
      • lookup:查表,用于前向
      • update:更新表,用于反向
    • python:实现CNEmbedding(bag)的自定义层,继承nn.Module
      • 主要成员:
        • emb_key:记录待查询key_id,用于后续更新embedding table
        • emb_vec:记录查询得到的emb vector,作为可训练参数,用于后续更新embedding teable
        • cn_emb_tbl:使用CNEmbedding_C接口创建的embedding table
      • forward:
        • 从cn_emb_tbl中查表得到emb vectors,(若是embeddingbag,还需要对它们做pooling:sum/mean/max) 赋值并返回emb_vec
      • backward:
        • 将返回的梯度更新到emb_vec,再cn_emb_tbl.update(emb_key, emb_vec)更新回embedding table
  • 实现方案:

    • 总体:使用pytorch扩展库的方式实现整个cn_embedding模块,c++端使用torch script、collection实现基本的嵌入表创建查找更新等操作,再封装为pytorch custom class向python端暴露接口;python端实现CNEmbedding和CNEmbeddingBag层,用于替换pytorch原生torch.nn.Embedding 和 torch.nn.EmbeddingBag。最后,再将Torchrec的数据结构EmbeddingBagCollection的嵌入层实现替换为CNEmbedding和CNEmbeddingBag

4.排期

  • [x] 6.5 调研nn.EmbeddingBag各参数的作用,底层实现方法,如何做insert,对齐python接口,调研开发pytorch c++扩展库方法
  • [x] 6.6-6.7 实现CNEmbedding创建与初始化
  • [x] 6.8-6.9 实现lookup&update
  • [x] 6.12-6.14 python端api开发
  • [x] 6.15-6.16 unittest & dlrm网络测试(需适配框架后才能测试)