Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cdpm model #261

Closed
wants to merge 9 commits into from
Closed

add cdpm model #261

wants to merge 9 commits into from

Conversation

xiaoiker
Copy link

No description provided.

@xiaoiker xiaoiker linked an issue Feb 15, 2023 that may be closed by this pull request
@xiaoiker xiaoiker marked this pull request as ready for review March 27, 2023 23:35
@Warvito Warvito added the need reviewer This PR need a reviewer label Apr 8, 2023
Copy link
Collaborator

@marksgraham marksgraham left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @xiaoiker

Thanks for this PR and sorry it took so long to get round to reviewing it. This is a partial review - I'm still running the notebook but it won't be done for a while, so I'll have a bunch more comments tomorrow.

Given the amount of type elapsed between PR and review, I think it might be worth rebasing on main and checking that doesn't break anything?

@@ -10,3 +10,4 @@
# limitations under the License.

from .vector_quantizer import EMAQuantizer, VectorQuantizer
from .vector_quantizer import EMAQuantizer, VectorQuantizer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mean to import from .RPE here, not duplicate the vector_quantizer import

Comment on lines +16 to +17
# from monai.networks.blocks import MLPBlock
# from monai.networks.layers import Act
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove unused imports

from generative.inferers import DiffusionInferer

# TODO: Add right import reference after deployed
from generative.networks.nets.cdpm import UNet_2Plus1_Model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add the inport to the __init__ under nets, and then import as:
from generative.networks.nets import UNet_2Plus1_Model


# %%
directory = os.environ.get("MONAI_DATA_DIRECTORY")
directory = '/home/Nobias/data/MONAI_DATA_DIRECTORY'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove your user-specific dir

Comment on lines +134 to +141
# # %%
# plt.subplots(1, 4, figsize=(10, 6))
# for i in range(4):
# plt.subplot(1, 4, i + 1)
# plt.imshow(train_ds[i * 20]["image"][0, :, :, 15].detach().cpu(), vmin=0, vmax=1, cmap="gray")
# plt.axis("off")
# plt.tight_layout()
# plt.show()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove these changes made to the 3d_ddpm tutorial

latent_row[self.sample_some_indices(max_indices=N, T=T)] = 1.
while True:
mask = obs_row if torch.rand(()) < 0.5 else latent_row
indices = torch.tensor(self.sample_some_indices(max_indices=N, T=T))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had to move the indices to the same device as batch1 using .to(batch1.device) or I got RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

return x


class UNet_2Plus1_Model(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this better be called CDPM?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also have here docstring description of what this model is and what it's for, explaining the conditional component with possible use examples. Perhaps even ConditionalDPM as a name to make it clearer what this is.

"""
Attention with slice relative position encoding by Wu et al. (https://arxiv.org/abs/2107.14222) and the official implementation
that can be found at https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py.
Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please leave a space between the preamble and the Args for all docstrings

class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please leave a space and add Args: before listing the arguments in all docstrings

# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to add from __future__ import annotations to be able to use | in type hints in older version of python

Comment on lines +112 to +117
for b in range(B):
for d in range(D):
for h in range(H):
for i in range(T):
for j in range(T):
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for b in range(B):
for d in range(D):
for h in range(H):
for i in range(T):
for j in range(T):
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h])
for b, d, h, i, j in np.ndindex(B, D, H, T, T):
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h])

class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change Sphinx style tags to google style tags, here and elsewhere in the code.

@marksgraham
Copy link
Collaborator

I'm finding the tutorial samples blank images after training for 70 epochs - is this expected?

image

@xiaoiker
Copy link
Author

xiaoiker commented Jul 3, 2023

I'm finding the tutorial samples blank images after training for 70 epochs - is this expected?

image

Oh really sorry for the late reply! Even myself thought forgot about this!

For the blank images, yes they are. Basically, cDPM needs much more training steps as it is trying to 3D information from 2D model.

I will fix the above issues while I am free. Thanks very much for your help.

@marksgraham
Copy link
Collaborator

I'm assuming there is no longer a plan to work on this and closing for now - please do reopen if you do want to finish the PR, @xiaoiker

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
need reviewer This PR need a reviewer
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add CDPM model
4 participants