From 274d883f92e6da548ccdeee704ea64e34bf2ff6e Mon Sep 17 00:00:00 2001 From: khaotik Date: Sun, 8 Aug 2021 13:57:02 +0800 Subject: [PATCH 1/5] Let SymbolBlock copy `lr_mult` `wd_mult` `init` attributes from symbol --- python/mxnet/gluon/block.py | 45 ++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index cff346b9f4aa..3c3ea1aec4d6 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1785,11 +1785,11 @@ def __init__(self, outputs, inputs, params=None): syms, self._in_format = _flatten(inputs, "input") out, self._out_format = _flatten(outputs, "output") - input_names = set() + input_name_set = set() for i in syms: assert len(i.get_internals().list_outputs()) == 1, \ "Input symbols must be variable, but %s is an output of operators"%str(i) - input_names.add(i.name) + input_name_set.add(i.name) # check if any symbol is row_sparse row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse'] @@ -1806,35 +1806,50 @@ def __init__(self, outputs, inputs, params=None): # Infer type of parameters. Without this, every parameter will be created with # default type i.e., fp32 - arg_params = out.list_arguments() - aux_params = out.list_auxiliary_states() + arg_param_li = out.list_arguments() + aux_param_li = out.list_auxiliary_states() + input_di = out.get_inputs() - arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params) + arg_type_li, aux_type_li = _infer_param_types(syms, out, arg_param_li, aux_param_li) if params is None: params = {} - unused_params = set(params.keys()) - set(arg_params) - set(aux_params) + unused_params = set(params.keys()) - set(arg_param_li) - set(aux_param_li) if len(unused_params) > 0: raise ValueError('{} params are unused by the model.'.format(unused_params)) self._reg_params = params - - for i, arg in enumerate(arg_params): + def _extract_initializer(_s_): + _initer_json_ = _s_.list_attr().get('__init__') + if _initer_json_ is None: + return None + _type_str_, _args_di_ = json.loads(_initer_json_) + return initializer.create(_type_str_, **_args_di_) + + for i, arg in enumerate(arg_param_li): if arg in self._reg_params: - self._reg_params[arg]._check_and_setattr(allow_deferred_init=True, dtype=arg_types[i]) + self._reg_params[arg]._check_and_setattr(allow_deferred_init=True, dtype=arg_type_li[i]) if self._reg_params[arg]._var is None: self._reg_params[arg]._var_name = arg - elif arg not in input_names: - self._reg_params[arg] = Parameter(name=arg, allow_deferred_init=True, dtype=arg_types[i]) + elif arg not in input_name_set: + sym_ = input_di[arg] + sym_attr = sym_.list_attr() + self._reg_params[arg] = Parameter( + name=arg, + init=_extract_initializer(sym_), + lr_mult=float(sym_attr.get('__lr_mult__', 1.0)), + wd_mult=float(sym_attr.get('__wd_mult__', 1.0)), + allow_deferred_init=True, + dtype=arg_type_li[i]) self._reg_params[arg]._var_name = arg - for i, aux in enumerate(aux_params): + for i, aux in enumerate(aux_param_li): if aux in self._reg_params: self._reg_params[aux]._check_and_setattr(grad_req='null', allow_deferred_init=True, - dtype=aux_types[i]) + dtype=aux_type_li[i]) if self._reg_params[aux]._var is None: self._reg_params[aux]._var_name = aux - elif aux not in input_names: + elif aux not in input_name_set: self._reg_params[aux] = Parameter(name=aux, grad_req='null', - allow_deferred_init=True, dtype=aux_types[i]) + allow_deferred_init=True, dtype=aux_type_li[i]) self._reg_params[aux]._var_name = aux self._cached_graph = syms, out From 8544c05592d26c30e67a05ec30169fa379390f69 Mon Sep 17 00:00:00 2001 From: khaotik Date: Sun, 29 May 2022 09:36:25 +0800 Subject: [PATCH 2/5] add test --- tests/python/unittest/test_gluon.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 33fd48a256a6..d3773e16274e 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -224,6 +224,24 @@ def test_basic(): model.setattr('grad_req', 'write') assert list(model.collect_params().values())[0]._grad is not None +@use_np +def test_symbol_block_init(): + DTYPE = mx.np.float32 + LR_MULT, WD_MULT = 0.555, 0.444 + svar = mx.symbol.var + s_x = svar('x', shape=(1,256,), dtype=DTYPE) + s_w = svar('W', shape=(256,192), dtype=DTYPE, lr_mult=LR_MULT, wd_mult=WD_MULT) + s_b = svar('b', shape=(1,192,), dtype=DTYPE, init=mx.init.Zero()) + s_y = mx.symbol.linalg.gemm(s_x, s_w, s_b) + + fn = mx.gluon.SymbolBlock([s_y], [s_x]) + fn.initialize() + fn.forward(mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE)) + param_di = fn.collect_params() + v_w, v_b = param_di['W'], param_di['b'] + assert v_w.lr_mult == LR_MULT + assert v_w.wd_mult == WD_MULT + assert not v_b.data().asnumpy().any() def test_sparse_symbol_block(): data = mx.sym.var('data') From 50a454710df72d2f9496521ca1fe144c191cd26a Mon Sep 17 00:00:00 2001 From: khaotik Date: Sun, 29 May 2022 19:28:07 +0800 Subject: [PATCH 3/5] fix tests --- python/mxnet/gluon/block.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 3c3ea1aec4d6..42e3d66a5e70 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1822,7 +1822,10 @@ def _extract_initializer(_s_): _initer_json_ = _s_.list_attr().get('__init__') if _initer_json_ is None: return None - _type_str_, _args_di_ = json.loads(_initer_json_) + try: + _type_str_, _args_di_ = json.loads(_initer_json_) + except json.JSONDecodeError as e: + _type_str_, _args_di_ = _initer_json_, {} return initializer.create(_type_str_, **_args_di_) for i, arg in enumerate(arg_param_li): From d66de4faed4e45102cedf1eda42d8138252a0a47 Mon Sep 17 00:00:00 2001 From: khaotik Date: Mon, 30 May 2022 08:50:41 +0800 Subject: [PATCH 4/5] use current_device() on test --- tests/python/unittest/test_gluon.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index d3773e16274e..a33babd2f20b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -236,7 +236,9 @@ def test_symbol_block_init(): fn = mx.gluon.SymbolBlock([s_y], [s_x]) fn.initialize() - fn.forward(mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE)) + v_x = mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE, + device=mx.device.current_device()) + fn.forward(v_x) param_di = fn.collect_params() v_w, v_b = param_di['W'], param_di['b'] assert v_w.lr_mult == LR_MULT From a12401a3c4db0fa10f416914ef429a03cf0c68c6 Mon Sep 17 00:00:00 2001 From: khaotik Date: Mon, 30 May 2022 09:26:11 +0800 Subject: [PATCH 5/5] device -> ctx for legacy ndarray --- tests/python/unittest/test_gluon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a33babd2f20b..952e9cf95c30 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -237,7 +237,7 @@ def test_symbol_block_init(): fn = mx.gluon.SymbolBlock([s_y], [s_x]) fn.initialize() v_x = mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE, - device=mx.device.current_device()) + ctx=mx.device.current_device()) fn.forward(v_x) param_di = fn.collect_params() v_w, v_b = param_di['W'], param_di['b']