diff --git a/python/mxnet/contrib/text/embedding.py b/python/mxnet/contrib/text/embedding.py index c08ef8452138..47d6c7e8ab14 100644 --- a/python/mxnet/contrib/text/embedding.py +++ b/python/mxnet/contrib/text/embedding.py @@ -226,7 +226,26 @@ def _get_pretrained_file(cls, embedding_root, pretrained_file_name): zf.extractall(embedding_dir) elif ext == '.gz': with tarfile.open(downloaded_file_path, 'r:gz') as tar: - tar.extractall(path=embedding_dir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=embedding_dir) return pretrained_file_path def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, encoding='utf8'): diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py index 70e3045e45cd..f627ca5ea7a0 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -181,7 +181,26 @@ def _get_data(self): sha1_hash=self._archive_file[1]) with tarfile.open(filename) as tar: - tar.extractall(self._root) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, self._root) if self._train: data_files = self._train_data diff --git a/tests/nightly/estimator/test_sentiment_rnn.py b/tests/nightly/estimator/test_sentiment_rnn.py index e481f70cda49..8cf0952db806 100644 --- a/tests/nightly/estimator/test_sentiment_rnn.py +++ b/tests/nightly/estimator/test_sentiment_rnn.py @@ -125,7 +125,26 @@ def download_imdb(data_dir='/tmp/data'): if not os.path.isfile(file_path): file_path = gluon.utils.download(url, data_dir, sha1_hash=sha1) with tarfile.open(file_path, 'r') as f: - f.extractall(data_dir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(f, data_dir) def read_imdb(folder='train'):