-
Notifications
You must be signed in to change notification settings - Fork 87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add cdpm model #261
add cdpm model #261
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @xiaoiker
Thanks for this PR and sorry it took so long to get round to reviewing it. This is a partial review - I'm still running the notebook but it won't be done for a while, so I'll have a bunch more comments tomorrow.
Given the amount of type elapsed between PR and review, I think it might be worth rebasing on main and checking that doesn't break anything?
@@ -10,3 +10,4 @@ | |||
# limitations under the License. | |||
|
|||
from .vector_quantizer import EMAQuantizer, VectorQuantizer | |||
from .vector_quantizer import EMAQuantizer, VectorQuantizer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you mean to import from .RPE here, not duplicate the vector_quantizer import
# from monai.networks.blocks import MLPBlock | ||
# from monai.networks.layers import Act |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove unused imports
from generative.inferers import DiffusionInferer | ||
|
||
# TODO: Add right import reference after deployed | ||
from generative.networks.nets.cdpm import UNet_2Plus1_Model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add the inport to the __init__
under nets, and then import as:
from generative.networks.nets import UNet_2Plus1_Model
|
||
# %% | ||
directory = os.environ.get("MONAI_DATA_DIRECTORY") | ||
directory = '/home/Nobias/data/MONAI_DATA_DIRECTORY' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove your user-specific dir
# # %% | ||
# plt.subplots(1, 4, figsize=(10, 6)) | ||
# for i in range(4): | ||
# plt.subplot(1, 4, i + 1) | ||
# plt.imshow(train_ds[i * 20]["image"][0, :, :, 15].detach().cpu(), vmin=0, vmax=1, cmap="gray") | ||
# plt.axis("off") | ||
# plt.tight_layout() | ||
# plt.show() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove these changes made to the 3d_ddpm tutorial
latent_row[self.sample_some_indices(max_indices=N, T=T)] = 1. | ||
while True: | ||
mask = obs_row if torch.rand(()) < 0.5 else latent_row | ||
indices = torch.tensor(self.sample_some_indices(max_indices=N, T=T)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i had to move the indices to the same device as batch1 using .to(batch1.device)
or I got RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
return x | ||
|
||
|
||
class UNet_2Plus1_Model(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would this better be called CDPM?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also have here docstring description of what this model is and what it's for, explaining the conditional component with possible use examples. Perhaps even ConditionalDPM
as a name to make it clearer what this is.
""" | ||
Attention with slice relative position encoding by Wu et al. (https://arxiv.org/abs/2107.14222) and the official implementation | ||
that can be found at https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py. | ||
Args: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please leave a space between the preamble and the Args for all docstrings
class Downsample(nn.Module): | ||
""" | ||
A downsampling layer with an optional convolution. | ||
:param channels: channels in the inputs and outputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please leave a space and add Args:
before listing the arguments in all docstrings
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import abstractmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to add from __future__ import annotations
to be able to use | in type hints in older version of python
for b in range(B): | ||
for d in range(D): | ||
for h in range(H): | ||
for i in range(T): | ||
for j in range(T): | ||
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for b in range(B): | |
for d in range(D): | |
for h in range(H): | |
for i in range(T): | |
for j in range(T): | |
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h]) | |
for b, d, h, i, j in np.ndindex(B, D, H, T, T): | |
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h]) | |
class Upsample(nn.Module): | ||
""" | ||
An upsampling layer with an optional convolution. | ||
:param channels: channels in the inputs and outputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change Sphinx style tags to google style tags, here and elsewhere in the code.
Oh really sorry for the late reply! Even myself thought forgot about this! For the blank images, yes they are. Basically, cDPM needs much more training steps as it is trying to 3D information from 2D model. I will fix the above issues while I am free. Thanks very much for your help. |
I'm assuming there is no longer a plan to work on this and closing for now - please do reopen if you do want to finish the PR, @xiaoiker |
No description provided.