-
Notifications
You must be signed in to change notification settings - Fork 951
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move model artifacts to github; Add progress bar
- Loading branch information
Showing
4 changed files
with
123 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import hashlib | ||
import os | ||
import shutil | ||
import sys | ||
import tempfile | ||
|
||
from urllib.request import urlopen, Request | ||
|
||
try: | ||
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available | ||
except ImportError: | ||
try: | ||
from tqdm import tqdm | ||
except ImportError: | ||
# fake tqdm if it's not installed | ||
class tqdm(object): # type: ignore | ||
|
||
def __init__(self, total=None, disable=False, | ||
unit=None, unit_scale=None, unit_divisor=None): | ||
self.total = total | ||
self.disable = disable | ||
self.n = 0 | ||
# ignore unit, unit_scale, unit_divisor; they're just for real tqdm | ||
|
||
def update(self, n): | ||
if self.disable: | ||
return | ||
|
||
self.n += n | ||
if self.total is None: | ||
sys.stderr.write("\r{0:.1f} bytes".format(self.n)) | ||
else: | ||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | ||
sys.stderr.flush() | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
if self.disable: | ||
return | ||
|
||
sys.stderr.write('\n') | ||
|
||
|
||
def download_url_to_file(url, dst, hash_prefix=None, progress=True): | ||
r"""Download object at the given URL to a local path. | ||
Args: | ||
url (string): URL of the object to download | ||
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` | ||
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. | ||
Default: None | ||
progress (bool, optional): whether or not to display a progress bar to stderr | ||
Default: True | ||
Example: | ||
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') | ||
""" | ||
file_size = None | ||
# We use a different API for python2 since urllib(2) doesn't recognize the CA | ||
# certificates in older Python | ||
req = Request(url, headers={"User-Agent": "torch.hub"}) | ||
u = urlopen(req) | ||
meta = u.info() | ||
if hasattr(meta, 'getheaders'): | ||
content_length = meta.getheaders("Content-Length") | ||
else: | ||
content_length = meta.get_all("Content-Length") | ||
if content_length is not None and len(content_length) > 0: | ||
file_size = int(content_length[0]) | ||
|
||
# We deliberately save it in a temp file and move it after | ||
# download is complete. This prevents a local working checkpoint | ||
# being overridden by a broken download. | ||
dst = os.path.expanduser(dst) | ||
dst_dir = os.path.dirname(dst) | ||
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | ||
|
||
try: | ||
if hash_prefix is not None: | ||
sha256 = hashlib.sha256() | ||
with tqdm(total=file_size, disable=not progress, | ||
unit='B', unit_scale=True, unit_divisor=1024) as pbar: | ||
while True: | ||
buffer = u.read(8192) | ||
if len(buffer) == 0: | ||
break | ||
f.write(buffer) | ||
if hash_prefix is not None: | ||
sha256.update(buffer) | ||
pbar.update(len(buffer)) | ||
|
||
f.close() | ||
if hash_prefix is not None: | ||
digest = sha256.hexdigest() | ||
if digest[:len(hash_prefix)] != hash_prefix: | ||
raise RuntimeError('invalid hash value (expected "{}", got "{}")' | ||
.format(hash_prefix, digest)) | ||
shutil.move(f.name, dst) | ||
finally: | ||
f.close() | ||
if os.path.exists(f.name): | ||
os.remove(f.name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters