Error running examples #9

Open
le000043 opened this issue May 24, 2021 · 1 comment

le000043 commented May 24, 2021

Hello, I encountered this error when running the following command:
python3 -m sam.sam_jax.train --dataset cifar10 --model_name WideResnet28x10 --output_dir /tmp/my_experiment --image_level_augmentations autoaugment --num_epochs 1 --sam_rho 0.05

Any help would be greatly appreciated.

Traceback (most recent call last):
  File "/home/dat/sam/sam/sam_jax/train.py", line 160, in main
    model, state = load_imagenet_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/imagenet_models/load_model.py", line 129, in get_model
    raise ModelNameError('Unrecognized model name.')
sam.sam_jax.imagenet_models.load_model.ModelNameError: Unrecognized model name.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/dat/sam/sam/sam_jax/train.py", line 177, in <module>
    app.run(main)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/dat/sam/sam/sam_jax/train.py", line 164, in main
    model, state = load_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 118, in get_model
    model, init_state = create_image_model(prng_key, batch_size, image_size,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 57, in create_image_model
    _, initial_params = module.init_by_shape(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 493, in init_by_shape
    stochastic_rng = stochastic.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 86, in make_rng
    return rng_frame.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 47, in make_rng
    return random.fold_in(self.base_rng, self.counter)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/random.py", line 289, in fold_in
    return _fold_in(key, jnp.uint32(data))
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 143, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/api.py", line 426, in cache_miss
    out_flat = xla.xla_call(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1565, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1556, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1568, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 609, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 874, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/dat/sam/sam/sam_jax/train.py", line 177, in <module>
    app.run(main)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/dat/sam/sam/sam_jax/train.py", line 164, in main
    model, state = load_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 118, in get_model
    model, init_state = create_image_model(prng_key, batch_size, image_size,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 57, in create_image_model
    _, initial_params = module.init_by_shape(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 493, in init_by_shape
    stochastic_rng = stochastic.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 86, in make_rng
    return rng_frame.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 47, in make_rng
    return random.fold_in(self.base_rng, self.counter)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/random.py", line 289, in fold_in
    return _fold_in(key, jnp.uint32(data))
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 874, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found.

ssbin4 commented Feb 15, 2023

Hi, I am getting a similar error message. Were you able to resolve the problem?
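
For anyone hitting the same message, one possible first step is to record which versions of the packages named in the traceback are installed and where each would be imported from. The traceback shows absl loading from ~/.local while jax and flax load from the venv, so a mixed environment is at least worth ruling out. A `cudaGetErrorString symbol not found` error may also point to a jaxlib build that does not match the local CUDA installation, though the thread does not confirm this. The sketch below uses only the standard library; the helper names (`installed_version`, `module_location`, `report`) are made up for illustration and are not part of this repository.

```python
import importlib.metadata
import importlib.util

# Distribution name -> importable module name, based on the paths in the traceback.
# This mapping is an assumption for illustration; extend it as needed.
PACKAGES = {"jax": "jax", "jaxlib": "jaxlib", "flax": "flax", "absl-py": "absl"}


def installed_version(dist_name):
    """Return the installed version string, or None if the distribution is missing."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def module_location(module_name):
    """Return the file a top-level module would be loaded from, without importing it."""
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        return None
    return spec.origin if spec is not None else None


def report(packages=PACKAGES):
    """Map each distribution name to a (version, location) pair."""
    return {
        dist: (installed_version(dist), module_location(mod))
        for dist, mod in packages.items()
    }


if __name__ == "__main__":
    for dist, (version, location) in report().items():
        print(f"{dist}: version={version} location={location}")
```

Running this with the same interpreter used for the training command and posting the output in the issue would give maintainers the jax and jaxlib versions, which the original report does not include.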
