Replies: 1 comment
-
Hi @Requiem8, see here.
-
Should it be input_.dtype instead of input_.dtype()?
From the 9th code cell of the extending_functionality tutorial, running on Google Colab with a T4 GPU:
trainer, model = build_trainer(rpu_config, log="two_pass")
print(model)
fit_model(trainer, model)
plot_loss(trainer, "FP update with non-ideal two-pass forward.")
Output:
LitAnalogModel(
(analog_model): AnalogSequential(
(0): AnalogLinearMapped(
in_features=784, out_features=256, bias=True, TwoPassTorchInferenceRPUConfig
(analog_module): TileModuleArray(
(array): ModuleList(
(0-1): 2 x ModuleList(
(0): TorchInferenceTile(
(tile): TwoPassTorchSimulatorTile(256, 392, cpu)
)
)
)
)
)
(1): Sigmoid()
(2): AnalogLinearMapped(
in_features=256, out_features=128, bias=True, TwoPassTorchInferenceRPUConfig
(analog_module): TorchInferenceTile(
(tile): TwoPassTorchSimulatorTile(128, 256, cpu)
)
)
(3): Sigmoid()
(4): AnalogLinearMapped(
in_features=128, out_features=10, bias=True, TwoPassTorchInferenceRPUConfig
(analog_module): TorchInferenceTile(
(tile): TwoPassTorchSimulatorTile(10, 128, cpu)
)
)
(5): LogSoftmax(dim=1)
)
)
Sanity Checking DataLoader 0: 0% | 0/2 [00:00<?, ?it/s]
TypeError Traceback (most recent call last)
in <cell line: 3>()
1 trainer, model = build_trainer(rpu_config, log="two_pass")
2 print(model)
----> 3 fit_model(trainer, model)
4 plot_loss(trainer, "FP update with non-ideal two-pass forward.")
27 frames
/usr/local/lib/python3.10/dist-packages/aihwkit/simulator/tiles/analog_mvm.py in matmul(cls, weight, input_, io_pars, trans, is_test, **fwd_pars)
92 ):
93 # - Shortcut, output would be all zeros
---> 94 return zeros(size=out_size, device=input_.device, dtype=input_.dtype())
95
96 if isinstance(nm_scale_values, Tensor):
TypeError: 'torch.dtype' object is not callable
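For context, torch.Tensor.dtype is an attribute, not a method, which is why the call on line 94 of analog_mvm.py raises this TypeError. A minimal standalone sketch (the tensor and output size below are illustrative, not taken from the tutorial) reproducing the error and the presumed fix of dropping the parentheses:

import torch

x = torch.randn(4, 8)
out_size = (4, 2)

# Works: .dtype is an attribute, so it is passed directly.
out = torch.zeros(size=out_size, device=x.device, dtype=x.dtype)

# Fails the same way as analog_mvm.py line 94: .dtype cannot be called.
try:
    out = torch.zeros(size=out_size, device=x.device, dtype=x.dtype())
except TypeError as err:
    print(err)  # 'torch.dtype' object is not callable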