Skip to content

Commit

Permalink
Fix logic and name for nonexistent arch test
Browse files Browse the repository at this point in the history
  • Loading branch information
gmarkall committed Mar 5, 2024
1 parent f574d3c commit ee087c8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
18 changes: 18 additions & 0 deletions pynvjitlink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ def alt_gpu_compute_capability(gpu_compute_capability):
return (7, 0)


@pytest.fixture(scope="session")
def absent_gpu_compute_capability(gpu_compute_capability, alt_gpu_compute_capability):
"""A compute capability that does not match the current GPU"""
# A compute capability not used in any cubin or fatbin test binary
cc_majors = {6, 7, 8}
cc_majors.remove(gpu_compute_capability[0])
cc_majors.remove(alt_gpu_compute_capability[0])
return (cc_majors.pop(), 0)


@pytest.fixture(scope="session")
def gpu_arch_flag(gpu_compute_capability):
"""nvJitLink arch flag to link for the current GPU"""
Expand All @@ -40,6 +50,14 @@ def alt_gpu_arch_flag(alt_gpu_compute_capability):
return f"-arch=sm_{major}{minor}"


@pytest.fixture(scope="session")
def absent_gpu_arch_flag(absent_gpu_compute_capability):
"""nvJitLink arch flag to link for an architecture not in any cubin or
fatbin"""
major, minor = absent_gpu_compute_capability
return f"-arch=sm_{major}{minor}"


@pytest.fixture(scope="session")
def device_functions_archive():
test_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down
6 changes: 4 additions & 2 deletions pynvjitlink/tests/test_pynvjitlink_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def test_add_fatbin_arch_2(device_functions_fatbin, alt_gpu_arch_flag):
nvjitlinker.add_fatbin(device_functions_fatbin, name)


def test_add_incompatible_fatbin_arch_error(device_functions_fatbin, alt_gpu_arch_flag):
nvjitlinker = NvJitLinker(alt_gpu_arch_flag)
def test_add_nonexistent_fatbin_arch_error(
device_functions_fatbin, absent_gpu_arch_flag
):
nvjitlinker = NvJitLinker(absent_gpu_arch_flag)
name = "test_device_functions.fatbin"
with pytest.raises(NvJitLinkError, match="NVJITLINK_ERROR_INVALID_INPUT error"):
nvjitlinker.add_fatbin(device_functions_fatbin, name)
Expand Down

0 comments on commit ee087c8

Please sign in to comment.