```python
import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig, EmbeddingConfig
from torchrec import EmbeddingBagCollection, EmbeddingCollection

# Regular KJT: both features 't1' and 't2' have batch size 3.
kt = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 0, 0, 2]),
    lengths=torch.tensor([1, 1, 1, 1, 0, 1], dtype=torch.int64),
)

# Variable-batch KJT: 't1' is deduplicated to 1 row, 't2' keeps 3 rows;
# inverse_indices maps each feature back to the full batch of 3.
kt2 = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 2]),
    lengths=torch.tensor([1, 1, 0, 1], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [3]],
    inverse_indices=(['t1', 't2'], torch.tensor([[0, 0, 0], [0, 1, 2]])),
)

eb_configs = [
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='e1', feature_names=['t1']),
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='e2', feature_names=['t2']),
]
ebc = EmbeddingBagCollection(eb_configs)
print(ebc(kt)['t1'])
print(ebc(kt2)['t1'])

ec_configs = [
    EmbeddingConfig(num_embeddings=100, embedding_dim=16, name='e1', feature_names=['t1']),
    EmbeddingConfig(num_embeddings=100, embedding_dim=16, name='e2', feature_names=['t2']),
]
ec = EmbeddingCollection(ec_configs)
print(ec(kt)["t1"].lengths().size())   # reported: torch.Size([3])
print(ec(kt2)["t1"].lengths().size())  # reported: torch.Size([1])
```
Result: the EmbeddingCollection output is not rearranged according to `inverse_indices`; the `t1` lengths are 3 and 1, respectively.
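For reference, the rearrangement the report expects can be sketched in plain Python. This is not TorchRec's implementation: `expand_jagged` is a hypothetical helper, ids stand in for embedding rows, and it assumes `inverse_indices[i]` names the deduplicated row that full-batch sample `i` should copy. Applied to `kt2`'s `t1` data, it yields the three rows that `kt` already has.

```python
from typing import List, Sequence, Tuple


def expand_jagged(
    lengths: Sequence[int],
    values: Sequence[int],
    inverse: Sequence[int],
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: expand a deduplicated jagged feature to the full batch.

    lengths/values describe the deduplicated rows; inverse[i] is the
    deduplicated row that full-batch sample i copies (assumed semantics).
    """
    # Start offset of each deduplicated row inside `values`.
    offsets = [0]
    for n in lengths:
        offsets.append(offsets[-1] + n)

    out_lengths: List[int] = []
    out_values: List[int] = []
    for row in inverse:
        start, end = offsets[row], offsets[row + 1]
        out_lengths.append(end - start)
        out_values.extend(values[start:end])
    return out_lengths, out_values


# 't1' in kt2: one deduplicated row, broadcast to a batch of 3.
print(expand_jagged([1], [0], [0, 0, 0]))  # ([1, 1, 1], [0, 0, 0])
# 't2' in kt2: already full batch, identity mapping.
print(expand_jagged([1, 0, 1], [0, 2], [0, 1, 2]))  # ([1, 0, 1], [0, 2])
```

Both results match the per-feature layout of `kt`, which is why the `t1` lengths from `EmbeddingCollection` would be expected to have size 3 for both inputs.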
cc @joshuadeng
Hi @yjjinjie, EmbeddingCollection currently does not support variable batch size per feature. This work is being planned, so stay tuned.