diff --git a/neuron_morphology/transforms/affine_transform.py b/neuron_morphology/transforms/affine_transform.py index 71506c56..8d11db55 100644 --- a/neuron_morphology/transforms/affine_transform.py +++ b/neuron_morphology/transforms/affine_transform.py @@ -160,14 +160,19 @@ def _get_scaling_factor(self) -> float: determinant = np.linalg.det(self.affine) return np.power(abs(determinant), 1.0 / 3.0) - def transform_morphology(self, morphology: Morphology, - clone: bool = False) -> Morphology: + def transform_morphology(self, + morphology: Morphology, + clone: bool = False, + scale_radius: bool = True, + ) -> Morphology: """ Apply this transform to all nodes in a morphology. Parameters ---------- morphology: a Morphology loaded from an swc file + clone: make a new object if True + scale_radius: apply radius scaling if True Returns ------- @@ -176,7 +181,10 @@ def transform_morphology(self, morphology: Morphology, if clone: morphology = morphology.clone() - scaling_factor = self._get_scaling_factor() + if scale_radius: + scaling_factor = self._get_scaling_factor() + else: + scaling_factor = 1 for node in morphology.nodes(): coordinates = np.array((node['x'], node['y'], node['z']), diff --git a/neuron_morphology/transforms/scale_correction/compute_scale_correction.py b/neuron_morphology/transforms/scale_correction/compute_scale_correction.py index 31ed9bc3..a1df79fc 100644 --- a/neuron_morphology/transforms/scale_correction/compute_scale_correction.py +++ b/neuron_morphology/transforms/scale_correction/compute_scale_correction.py @@ -94,7 +94,8 @@ def run_scale_correction( [0, 0, scale_correction]] ) at = aff.AffineTransform(scale_transform) - morphology_scaled = at.transform_morphology(morphology) + morphology_scaled = at.transform_morphology(morphology, + scale_radius=False) return { "morphology_scaled": morphology_scaled, diff --git a/tests/transforms/test_scale_correction.py b/tests/transforms/test_scale_correction.py index 9f40ffcf..cddab548 100644 --- a/tests/transforms/test_scale_correction.py +++ b/tests/transforms/test_scale_correction.py @@ -18,8 +18,8 @@ def setUp(self): self.morphology = ( MorphologyBuilder() - .root(1, 2, 3) - .axon(0, 2, 100) + .root(1, 2, 3, radius=1.3) + .axon(0, 2, 100, radius=0.4) .build() ) self.max_morphology = ( @@ -79,12 +79,14 @@ def test_scale_correction_end_to_end(self): self.assertAlmostEqual(outputs['scale_correction'], 2.0) aff_t = AffineTransform.from_dict(outputs['scale_transform']) - morph_t = aff_t.transform_morphology(self.morphology) + morph_t = aff_t.transform_morphology(self.morphology, + scale_radius=False) axon = morph_t.node_by_id(1) self.assertAlmostEqual(axon['x'], 0) self.assertAlmostEqual(axon['y'], 2) self.assertAlmostEqual(axon['z'], 200) + self.assertAlmostEqual(axon['radius'], 0.4) def test_run_scale_correction(self):