Skip to content

Commit

Permalink
Move model artifacts to github; Add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
timesler committed Jul 5, 2020
1 parent 5dcce36 commit 3871733
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 20 deletions.
30 changes: 12 additions & 18 deletions models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import requests
from requests.adapters import HTTPAdapter

import torch
from torch import nn
from torch.nn import functional as F
import requests
from requests.adapters import HTTPAdapter
import os

from .utils.download import download_url_to_file


class BasicConv2d(nn.Module):
Expand Down Expand Up @@ -310,29 +313,20 @@ def load_weights(mdl, name):
ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
"""
if name == 'vggface2':
features_path = 'https://drive.google.com/uc?export=download&id=1cWLH_hPns8kSfMz9kKl9PsG5aNV2VSMn'
logits_path = 'https://drive.google.com/uc?export=download&id=1mAie3nzZeno9UIzFXvmVZrDG3kwML46X'
path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
elif name == 'casia-webface':
features_path = 'https://drive.google.com/uc?export=download&id=1LSHHee_IQj5W3vjBcRyVaALv4py1XaGy'
logits_path = 'https://drive.google.com/uc?export=download&id=1QrhPgn1bGlDxAil2uc07ctunCQoDnCzT'
path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
else:
raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')

model_dir = os.path.join(get_torch_home(), 'checkpoints')
os.makedirs(model_dir, exist_ok=True)

state_dict = {}
for i, path in enumerate([features_path, logits_path]):
cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:]))
if not os.path.exists(cached_file):
print('Downloading parameters ({}/2)'.format(i+1))
s = requests.Session()
s.mount('https://', HTTPAdapter(max_retries=10))
r = s.get(path, allow_redirects=True)
with open(cached_file, 'wb') as f:
f.write(r.content)
state_dict.update(torch.load(cached_file))
cached_file = os.path.join(model_dir, os.path.basename(path))
if not os.path.exists(cached_file):
download_url_to_file(path, cached_file)

state_dict = torch.load(cached_file)
mdl.load_state_dict(state_dict)


Expand Down
7 changes: 6 additions & 1 deletion models/utils/detect_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from torch.nn.functional import interpolate
from torchvision.transforms import functional as F
from torchvision.ops.boxes import batched_nms
import cv2
from PIL import Image
import numpy as np
import os

# OpenCV is optional, but required if using numpy arrays instead of PIL
try:
import cv2
except:
pass


def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
if isinstance(imgs, (np.ndarray, torch.Tensor)):
Expand Down
102 changes: 102 additions & 0 deletions models/utils/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import hashlib
import os
import shutil
import sys
import tempfile

from urllib.request import urlopen, Request

try:
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
except ImportError:
try:
from tqdm import tqdm
except ImportError:
# fake tqdm if it's not installed
class tqdm(object): # type: ignore

def __init__(self, total=None, disable=False,
unit=None, unit_scale=None, unit_divisor=None):
self.total = total
self.disable = disable
self.n = 0
# ignore unit, unit_scale, unit_divisor; they're just for real tqdm

def update(self, n):
if self.disable:
return

self.n += n
if self.total is None:
sys.stderr.write("\r{0:.1f} bytes".format(self.n))
else:
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
sys.stderr.flush()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.disable:
return

sys.stderr.write('\n')


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
r"""Download object at the given URL to a local path.
Args:
url (string): URL of the object to download
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
Default: None
progress (bool, optional): whether or not to display a progress bar to stderr
Default: True
Example:
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
"""
file_size = None
# We use a different API for python2 since urllib(2) doesn't recognize the CA
# certificates in older Python
req = Request(url, headers={"User-Agent": "torch.hub"})
u = urlopen(req)
meta = u.info()
if hasattr(meta, 'getheaders'):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
if content_length is not None and len(content_length) > 0:
file_size = int(content_length[0])

# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
with tqdm(total=file_size, disable=not progress,
unit='B', unit_scale=True, unit_divisor=1024) as pbar:
while True:
buffer = u.read(8192)
if len(buffer) == 0:
break
f.write(buffer)
if hash_prefix is not None:
sha256.update(buffer)
pbar.update(len(buffer))

f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()
if digest[:len(hash_prefix)] != hash_prefix:
raise RuntimeError('invalid hash value (expected "{}", got "{}")'
.format(hash_prefix, digest))
shutil.move(f.name, dst)
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import setuptools, os

PACKAGE_NAME = 'facenet-pytorch'
VERSION = '2.2.9'
VERSION = '2.3.0'
AUTHOR = 'Tim Esler'
EMAIL = 'tim.esler@gmail.com'
DESCRIPTION = 'Pretrained Pytorch face detection and recognition models'
Expand Down Expand Up @@ -38,5 +38,7 @@
install_requires=[
'numpy',
'requests',
'torchvision',
'pillow',
],
)

0 comments on commit 3871733

Please sign in to comment.