-
Notifications
You must be signed in to change notification settings - Fork 476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement lu_unpack in jax #8262
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
qihqi
reviewed
Oct 15, 2024
P = P.at[..., col_idx, row_idx].set(pi) | ||
print("debug: p2:", P) | ||
else: | ||
_pivots = LU_pivots - 1 # pivots are offset by 1 in jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this branch looks more correct according to this explanation:
Pivot Matrix Representation
PyTorch uses a compact representation for the permutation matrix. Instead of storing the full matrix, it stores a 1-indexed vector called pivots.
Each element pivots[i] indicates the row that the i-th row was swapped with during the LU decomposition process.
Importantly, it uses 1-indexing, meaning pivots[i] = j signifies that the i-th row was swapped with the (j-1)-th row.
Your Example: [2, 2] -> [[0, 1], [1, 0]]
Let's analyze your example step-by-step:
Input pivots = [2, 2]: This vector tells us:
In the 1st step, the 1st row was swapped with the (2-1) = 1st row (essentially, no swap).
In the 2nd step, the 2nd row was swapped with the (2-1) = 1st row.
Constructing the Permutation Matrix: We start with an identity matrix and apply the swaps indicated by the pivots vector:
Initial Identity Matrix:
[[1, 0],
[0, 1]]
Step 1 (no swap): The matrix remains unchanged.
Step 2 (swap rows 2 and 1):
[[0, 1],
[1, 0]]
qihqi
reviewed
Oct 15, 2024
indices[i], indices[_pivots[i]] = indices[_pivots[i]], indices[i] | ||
#print("[debug]: i, pivot[i], indices:", i, _pivots[i], indices) | ||
P = P[jnp.array(indices)] | ||
P = jnp.transpose(P) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why transpose here?
barney-s
force-pushed
the
lu_unpack
branch
10 times, most recently
from
October 18, 2024 06:25
9a292c5
to
be69b70
Compare
- For Lower and Upper matrices, use jnp.tril, jnp.triu - Shape both triangle matrices and add 1's to the lower triangle to match the pytorch behavior - For 2D permutation matrix start with an identity matrix and mutate it based on pivots - start with sequential indices and apply the pivot operations to it - finally use the pivoted indices to index the identity matrix to generate the final permutation matrix for that pivot. - For 2D inputs and 1D pivot the above logic would work - For >=3d inputs, we first reshape the inputs to become 3D and then call vmap along the first dim with the 2d logic for each 2d matrix
qihqi
approved these changes
Oct 18, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Implement lu_unpack in jax
match the pytorch behavior
based on pivots
generate the final permutation matrix for that pivot.
call vmap along the first dim with the 2d logic for each 2d matrix
Fixes: #7507