Skip to content

Commit

Permalink
Merge pull request #8 from galmetzer/master
Browse files Browse the repository at this point in the history
Beamgap loss and G Demo
  • Loading branch information
galmetzer authored Jun 3, 2020
2 parents c1a2079 + 9a7f26a commit c56d89b
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 6 deletions.
20 changes: 17 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from models.networks import init_net, sample_surface, local_nonuniform_penalty
import utils
import numpy as np
from models.losses import chamfer_distance
from models.losses import chamfer_distance, BeamGapLoss
from options import Options
import time
import os
Expand All @@ -30,6 +30,12 @@
print(f'number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)

beamgap_loss = BeamGapLoss(device)

if opts.beamgap_iterations > 0:
print('beamgap on')
beamgap_loss.update_pm(part_mesh, torch.cat([input_xyz, input_normals], dim=-1))

for i in range(opts.iterations):
num_samples = options.get_num_samples(i % opts.upsamp)
if opts.global_step:
Expand All @@ -45,7 +51,11 @@
recon_xyz, recon_normals = recon_xyz.type(options.dtype()), recon_normals.type(options.dtype())
xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(recon_xyz, input_xyz, x_normals=recon_normals, y_normals=input_normals,
unoriented=opts.unoriented)
loss = (xyz_chamfer_loss + (opts.ang_wt * normals_chamfer_loss))

if (i < opts.beamgap_iterations) and (i % opts.beamgap_modulo == 0):
loss = beamgap_loss(part_mesh, part_i)
else:
loss = (xyz_chamfer_loss + (opts.ang_wt * normals_chamfer_loss))
if opts.local_non_uniform > 0:
loss += opts.local_non_uniform * local_nonuniform_penalty(part_mesh.main_mesh).float()
loss.backward()
Expand All @@ -70,14 +80,18 @@
mesh = part_mesh.main_mesh
num_faces = int(np.clip(len(mesh.faces) * 1.5, len(mesh.faces), opts.max_faces))

if num_faces > len(mesh.faces):
if num_faces > len(mesh.faces) or opts.manifold_always:
# up-sample mesh
mesh = utils.manifold_upsample(mesh, opts.save_path, Mesh,
num_faces=min(num_faces, opts.max_faces),
res=opts.manifold_res, simplify=True)

part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
print(f'upsampled to {len(mesh.faces)} faces; number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)
if i < opts.beamgap_iterations:
print('beamgap updated')
beamgap_loss.update_pm(part_mesh, input_xyz)

with torch.no_grad():
mesh.export(os.path.join(opts.save_path, 'last_recon.obj'))
58 changes: 58 additions & 0 deletions models/layers/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import copy
from pathlib import Path
import pickle
from pytorch3d.ops.knn import knn_gather, knn_points


class Mesh:

Expand Down Expand Up @@ -37,6 +39,7 @@ def __init__(self, file, hold_history=False, vs=None, faces=None, device='cpu',
self.faces = torch.from_numpy(self.faces)
self.vs = self.vs.to(self.device)
self.faces = self.faces.to(self.device).long()
self.area, self.normals = self.face_areas_normals(self.vs, self.faces)

def build_gemm(self):
self.ve = [[] for _ in self.vs]
Expand Down Expand Up @@ -145,7 +148,62 @@ def normalize_unit_bb(self):
self.vs /= self.scale
self.vs += self.translations[None, :]

def discrete_project(self, pc: torch.Tensor, thres=0.9, cpu=False):
with torch.no_grad():
device = torch.device('cpu') if cpu else self.device
pc = pc.double()
if isinstance(self, Mesh):
mid_points = self.vs[self.faces].mean(dim=1)
normals = self.normals
else:
mid_points = self[:, :3]
normals = self[:, 3:]
pk12 = knn_points(mid_points[:, :3].unsqueeze(0), pc[:, :, :3], K=3).idx[0]
pk21 = knn_points(pc[:, :, :3], mid_points[:, :3].unsqueeze(0), K=3).idx[0]
loop = pk21[pk12].view(pk12.shape[0], -1)
knn_mask = (loop == torch.arange(0, pk12.shape[0], device=self.device)[:, None]).sum(dim=1) > 0
mid_points = mid_points.to(device)
pc = pc[0].to(device)
normals = normals.to(device)[~ knn_mask, :]
masked_mid_points = mid_points[~ knn_mask, :]
displacement = masked_mid_points[:, None, :] - pc[:, :3]
torch.cuda.empty_cache()
distance = displacement.norm(dim=-1)
mask = (torch.abs(torch.sum((displacement / distance[:, :, None]) *
normals[:, None, :], dim=-1)) > thres)
if pc.shape[-1] == 6:
pc_normals = pc[:, 3:]
normals_correlation = torch.sum(normals[:, None, :] * pc_normals, dim=-1)
mask = mask * (normals_correlation > 0)
torch.cuda.empty_cache()
distance[~ mask] += float('inf')
min, argmin = distance.min(dim=-1)

pc_per_face_masked = pc[argmin, :].clone()
pc_per_face_masked[min == float('inf'), :] = float('nan')
pc_per_face = torch.zeros(mid_points.shape[0], 6).\
type(pc_per_face_masked.dtype).to(pc_per_face_masked.device)
pc_per_face[~ knn_mask, :pc.shape[-1]] = pc_per_face_masked
pc_per_face[knn_mask, :] = float('nan')

# clean up
del knn_mask
return pc_per_face.to(self.device), (pc_per_face[:, 0] == pc_per_face[:, 0]).to(device)

@staticmethod
def face_areas_normals(vs, faces):
if type(vs) is not torch.Tensor:
vs = torch.from_numpy(vs)
if type(faces) is not torch.Tensor:
faces = torch.from_numpy(faces)
face_normals = torch.cross(vs[faces[:, 1]] - vs[faces[:, 0]],
vs[faces[:, 2]] - vs[faces[:, 1]])

face_areas = torch.norm(face_normals, dim=1)
face_normals = face_normals / face_areas[:, None]
face_areas = 0.5 * face_areas
face_areas = 0.5 * face_areas
return face_areas, face_normals

def update_verts(self, verts):
"""
Expand Down
39 changes: 39 additions & 0 deletions models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,42 @@ def chamfer_distance(
cham_normals = cham_norm_x + cham_norm_y if return_normals else None

return cham_dist, cham_normals


class ZeroNanGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x

@staticmethod
def backward(ctx, grad):
grad[grad != grad] = 0
return grad


class BeamGapLoss:
def __init__(self, device):
self.device = device
self.points, self.masks = None, None

def update_pm(self, pmesh, target_pc):
points, masks = [], []
target_pc.to(self.device)
total_mask = torch.zeros(pmesh.main_mesh.vs.shape[0])
for i, m in enumerate(pmesh):
p, mask = m.discrete_project(target_pc, thres=0.99, cpu=True)
p, mask = p.to(target_pc.device), mask.to(target_pc.device)
points.append(p[:, :3])
masks.append(mask)
temp = torch.zeros(m.vs.shape[0])
if (mask != False).any():
temp[m.faces[mask]] = 1
total_mask[pmesh.sub_mesh_index[i]] += temp
self.points, self.masks = points, masks

def __call__(self, pmesh, j):
losses = self.points[j] - pmesh[j].vs[pmesh[j].faces].mean(dim=1)
losses = ZeroNanGrad.apply(losses)
losses = torch.norm(losses, dim=1)[self.masks[j]]
l2 = losses.mean().float()
return l2 * 1e1
7 changes: 7 additions & 0 deletions options.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ def parse_args(self):
parser.add_argument('--init-weights', type=float, default=0.002, help='initialize NN with this size')
#
parser.add_argument('--export-interval', type=int, metavar='N', default=100, help='export interval')
parser.add_argument('--beamgap-iterations', type=int, default=0,
help='the # iters to which the beamgap loss will be calculated')
parser.add_argument('--beamgap-modulo', type=int, default=1, help='skip iterations with beamgap loss'
'; calc beamgap when:'
' iter % (--beamgap-modulo) == 0')
parser.add_argument('--manifold-always', action='store_true',
help='always run manifold even when the maximum number of faces is reached')

self.args = parser.parse_args()

Expand Down
14 changes: 14 additions & 0 deletions scripts/examples/g.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
python main.py --lr 0.0001 \
--input-pc ./data/g.ply \
--upsamp 100 \
--initial-mesh ./data/g_initmesh.obj \
--save-path checkpoints/g \
--iterations 3000 \
--beamgap-iterations 800 \
--upsamp 100 \
--beamgap-modulo 2 \
--manifold-res 4000 \
--convs 64 64 64 128 \
--pools 0 0 0 \
--max-faces 10000 \
--manifold-always
12 changes: 9 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import numpy as np
import glob
import os
import uuid
from options import MANIFOLD_DIR
Expand All @@ -12,12 +11,16 @@ def manifold_upsample(mesh, save_path, Mesh, num_faces=2000, res=3000, simplify=

temp_file = os.path.join(save_path, random_file_name('obj'))
opts = ' ' + str(res) if res is not None else ''
cmd = "{} {} {}".format(os.path.join(MANIFOLD_DIR, 'manifold'), fname, temp_file + opts)

manifold_script_path = os.path.join(MANIFOLD_DIR, 'manifold')
if not os.path.exists(manifold_script_path):
raise FileNotFoundError(f'{manifold_script_path} not found')
cmd = "{} {} {}".format(manifold_script_path, fname, temp_file + opts)
os.system(cmd)

if simplify:
cmd = "{} -i {} -o {} -f {}".format(os.path.join(MANIFOLD_DIR, 'simplify'), temp_file,
temp_file, num_faces)
temp_file, num_faces)
os.system(cmd)

m_out = Mesh(temp_file, hold_history=True, device=mesh.device)
Expand All @@ -26,6 +29,7 @@ def manifold_upsample(mesh, save_path, Mesh, num_faces=2000, res=3000, simplify=
[os.remove(_) for _ in list(glob.glob(os.path.splitext(temp_file)[0] + '*'))]
return m_out


def read_pts(pts_file):
'''
:param pts_file: file path of a plain text list of points
Expand All @@ -48,6 +52,7 @@ def read_pts(pts_file):
pass
return np.array(xyz, dtype=np.float32), np.array(normals, dtype=np.float32)


def load_obj(file):
vs, faces = [], []
f = open(file)
Expand All @@ -70,6 +75,7 @@ def load_obj(file):
assert np.logical_and(faces >= 0, faces < len(vs)).all()
return vs, faces


def export(file, vs, faces, vn=None, color=None):
with open(file, 'w+') as f:
for vi, v in enumerate(vs):
Expand Down

0 comments on commit c56d89b

Please sign in to comment.