-
Notifications
You must be signed in to change notification settings - Fork 476
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- For Lower and Upper matrices, use jnp.tril, jnp.triu - Shape both triangle matrices and add 1's to the lower triangle to match the pytorch behavior - For 2D permutation matrix start with an identity matrix and mutate it based on pivots - start with sequential indices and apply the pivot operations to it - finally use the pivoted indices to index the identity matrix to generate the final permutation matrix for that pivot. - For 2D inputs and 1D pivot the above logic would work - For >=3d inputs, we first reshape the inputs to become 3D and then call vmap along the first dim with the 2d logic for each 2d matrix
- Loading branch information
Showing
2 changed files
with
102 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters