Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] Rewrite check_binary_symbols as Python script #1978

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 1 addition & 74 deletions check_binary.sh
Original file line number Diff line number Diff line change
Expand Up @@ -123,81 +123,8 @@ if [[ "$(uname)" != 'Darwin' ]]; then

# We also check that there are [not] cxx11 symbols in libtorch
#
# To check whether it is using cxx11 ABI, check non-existence of symbol:
PRE_CXX11_SYMBOLS=(
"std::basic_string<"
"std::list"
)
# To check whether it is using pre-cxx11 ABI, check non-existence of symbol:
CXX11_SYMBOLS=(
"std::__cxx11::basic_string"
"std::__cxx11::list"
)
# NOTE: Checking the above symbols in all namespaces doesn't work, because
# devtoolset7 always produces some cxx11 symbols even if we build with old ABI,
# and CuDNN always has pre-cxx11 symbols even if we build with new ABI using gcc 5.4.
# Instead, we *only* check the above symbols in the following namespaces:
LIBTORCH_NAMESPACE_LIST=(
"c10::"
"at::"
"caffe2::"
"torch::"
)
echo "Checking that symbols in libtorch.so have the right gcc abi"
grep_symbols () {
symbols=("$@")
for namespace in "${LIBTORCH_NAMESPACE_LIST[@]}"
do
for symbol in "${symbols[@]}"
do
nm "$lib" | c++filt | grep " $namespace".*$symbol
done
done
}
check_lib_symbols_for_abi_correctness () {
lib=$1
echo "lib: " $lib
if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then
num_pre_cxx11_symbols=$(grep_symbols "${PRE_CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_pre_cxx11_symbols: " $num_pre_cxx11_symbols
if [[ "$num_pre_cxx11_symbols" -gt 0 ]]; then
echo "Found pre-cxx11 symbols but there shouldn't be. Dumping symbols"
grep_symbols "${PRE_CXX11_SYMBOLS[@]}"
exit 1
fi
num_cxx11_symbols=$(grep_symbols "${CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_cxx11_symbols: " $num_cxx11_symbols
if [[ "$num_cxx11_symbols" -lt 1000 ]]; then
echo "Didn't find enough cxx11 symbols. Aborting."
exit 1
fi
else
num_cxx11_symbols=$(grep_symbols "${CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_cxx11_symbols: " $num_cxx11_symbols
if [[ "$num_cxx11_symbols" -gt 0 ]]; then
echo "Found cxx11 symbols but there shouldn't be. Dumping symbols"
grep_symbols "${CXX11_SYMBOLS[@]}"
exit 1
fi
num_pre_cxx11_symbols=$(grep_symbols "${PRE_CXX11_SYMBOLS[@]}" | wc -l) || true
echo "num_pre_cxx11_symbols: " $num_pre_cxx11_symbols
if [[ "$num_pre_cxx11_symbols" -lt 1000 ]]; then
echo "Didn't find enough pre-cxx11 symbols. Aborting."
exit 1
fi
fi
}
# After https://github.com/pytorch/pytorch/pull/29731 most of the real
# libtorch code will live in libtorch_cpu, not libtorch, so cxx11
# symbol counting won't work on libtorch (since there's nothing in
# it.) Fortunately, libtorch_cpu.so doesn't exist prior to this PR,
# so just test if the file exists and use it if it does.
if [ -f "${install_root}/lib/libtorch_cpu.so" ]; then
libtorch="${install_root}/lib/libtorch_cpu.so"
else
libtorch="${install_root}/lib/libtorch.so"
fi
check_lib_symbols_for_abi_correctness $libtorch
python test/check_binary_symbols.py

echo "cxx11 symbols seem to be in order"
fi # if on Darwin
Expand Down
90 changes: 90 additions & 0 deletions test/check_binary_symbols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
import concurrent.futures
import distutils.sysconfig
import itertools
import functools
import os
import re
from pathlib import Path

# We also check that there are [not] cxx11 symbols in libtorch
#
# To check whether it is using cxx11 ABI, check non-existence of symbol:
PRE_CXX11_SYMBOLS=(
"std::basic_string<",
"std::list",
)
# To check whether it is using pre-cxx11 ABI, check non-existence of symbol:
CXX11_SYMBOLS=(
"std::__cxx11::basic_string",
"std::__cxx11::list",
)
# NOTE: Checking the above symbols in all namespaces doesn't work, because
# devtoolset7 always produces some cxx11 symbols even if we build with old ABI,
# and CuDNN always has pre-cxx11 symbols even if we build with new ABI using gcc 5.4.
# Instead, we *only* check the above symbols in the following namespaces:
LIBTORCH_NAMESPACE_LIST=(
"c10::",
"at::",
"caffe2::",
"torch::",
)

LIBTORCH_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, CXX11_SYMBOLS)]

LIBTORCH_PRE_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, PRE_CXX11_SYMBOLS)]

@functools.lru_cache
def get_symbols(lib :str ) -> list[tuple[str, str, str]]:
from subprocess import check_output
lines = check_output(f'nm "{lib}"|c++filt', shell=True)
return [x.split(' ', 2) for x in lines.decode('latin1').split('\n')[:-1]]


def grep_symbols(lib: str, patterns: list[re.Match]) -> list[str]:
def _grep_symbols(symbols: list[tuple[str, str, str]], patterns: list[re.Match]) -> list[str]:
rc = []
for s_addr, s_type, s_name in symbols:
for pattern in patterns:
if pattern.match(s_name):
rc.append(s_name)
continue
return rc
all_symbols = get_symbols(lib)
num_workers= 32
chunk_size = (len(all_symbols) + num_workers - 1 ) // num_workers
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
tasks = [executor.submit(_grep_symbols, all_symbols[i * chunk_size : (i + 1) * chunk_size], patterns) for i in range(num_workers)]
return sum((x.result() for x in tasks), [])

def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True) -> None:
print(f"lib: {lib}")
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
pre_cxx11_symbols = grep_symbols(lib, LIBTORCH_PRE_CXX11_PATTERNS)
num_cxx11_symbols = len(cxx11_symbols)
num_pre_cxx11_symbols = len(pre_cxx11_symbols)
print(f"num_cxx11_symbols: {num_cxx11_symbols}")
print(f"num_pre_cxx11_symbols: {num_pre_cxx11_symbols}")
if pre_cxx11_abi:
if num_cxx11_symbols > 0:
raise RuntimeError(f"Found cxx11 symbols, but there shouldn't be any, see: {cxx11_symbols[:100]}")
if num_pre_cxx11_symbols < 1000:
raise RuntimeError("Didn't find enough pre-cxx11 symbols.")
else:
if num_pre_cxx11_symbols > 0:
raise RuntimeError(f"Found pre-cxx11 symbols, but there shouldn't be any, see: {pre_cxx11_symbols[:100]}")
if num_cxx11_symbols < 100:
raise RuntimeError("Didn't find enought cxx11 symbols")

def main() -> None:
if os.getenv("PACKAGE_TYPE") == "libtorch":
install_root = Path(__file__).parent.parent
else:
install_root = Path(distutils.sysconfig.get_python_lib()) / "torch"
libtorch_cpu_path = install_root / "lib" / "libtorch_cpu.so"
pre_cxx11_abi = "cxx11-abi" not in os.getenv("DESIRED_DEVTOOLSET", "")
check_lib_symbols_for_abi_correctness(libtorch_cpu_path, pre_cxx11_abi)


if __name__ == "__main__":
main()
Loading