"""
Nvdiffrast-based renderers, equivalent to those based on PyTorch3D but sometimes faster.
This module can be imported even when nvdiffrast is not installed; in that case, get_glctx_session and NvdiffrastRenderer raise the NvdiffrastNotFound exception.
"""
from dataclasses import dataclass
from typing import Optional
import logging
import torch
from torch.nn import functional as F
# nvdiffrast is an optional dependency: record whether it imported so that the renderer can raise
# NvdiffrastNotFound lazily instead of breaking the import of this module.
try:
    import nvdiffrast.torch as dr
    is_available = True
except ImportError:
    dr = None
    is_available = False
from torchdrivesim.mesh import RGBMesh
from torchdrivesim.rendering.base import RendererConfig, BirdviewRenderer, Cameras
from torchdrivesim.utils import Resolution
logger = logging.getLogger(__name__)

#: Cache of nvdiffrast rasterization contexts keyed by ``torch.device``; read and filled by
#: ``get_glctx_session``, which otherwise fails with NameError because nothing else defines it.
glctx_sessions: dict = {}
class NvdiffrastNotFound(ImportError):
    """Raised when an nvdiffrast-based feature is requested but ``nvdiffrast`` could not be imported."""
def get_glctx_session(device, opengl=True):
    """
    Fetch (or lazily create) the nvdiffrast rasterization context for a device.

    Contexts are kept in the module-level ``glctx_sessions`` dictionary, keyed by ``torch.device``,
    so renderers on the same device share one context. An OpenGL context (``RasterizeGLContext``
    with ``output_db=False``) is used when ``opengl`` is True or when the installed nvdiffrast has no
    ``RasterizeCudaContext``; otherwise a ``RasterizeCudaContext`` is used. If a cached context is of
    the opposite kind to the one requested (and CUDA contexts exist), it is replaced in the cache.

    Args:
        device: anything accepted by ``torch.device``
        opengl: request an OpenGL context rather than a CUDA one
    Returns:
        the cached context, or None for a CPU device
    Raises:
        NvdiffrastNotFound: if nvdiffrast could not be imported
    """
    if not is_available:
        raise NvdiffrastNotFound()
    dev = torch.device(device)
    if dev.type == 'cpu':
        logger.debug('\'nvdiffrast\' supports only rendering on GPU.')
        return None

    # Older nvdiffrast releases only provide the OpenGL rasterizer.
    cuda_contexts_supported = hasattr(dr, 'RasterizeCudaContext')

    def new_gl_context():
        return dr.RasterizeGLContext(device=dev, output_db=False)

    def new_cuda_context():
        return dr.RasterizeCudaContext(device=dev)

    if dev in glctx_sessions:
        if cuda_contexts_supported:
            cached = glctx_sessions[dev]
            # Swap the cached context when it is of the opposite kind to the one requested.
            if opengl and isinstance(cached, dr.RasterizeCudaContext):
                glctx_sessions[dev] = new_gl_context()
            elif not opengl and isinstance(cached, dr.RasterizeGLContext):
                glctx_sessions[dev] = new_cuda_context()
    else:
        use_gl = opengl or not cuda_contexts_supported
        glctx_sessions[dev] = new_gl_context() if use_gl else new_cuda_context()
    return glctx_sessions[dev]
@dataclass
class NvdiffrastRendererConfig(RendererConfig):
    """
    Settings for the nvdiffrast-backed birdview renderer, on top of the generic renderer configuration.
    """
    #: backend identifier carried by this config
    backend: str = 'nvdiffrast'
    #: when True, the rendered image is passed through ``nvdiffrast.torch.antialias``
    antialias: bool = False
    #: True requests an OpenGL rasterization context; False requests CUDA
    #: (OpenGL is still used if the installed nvdiffrast has no CUDA context)
    opengl: bool = False
    #: when set, a dummy batch of this size is rendered once per context to pre-allocate memory,
    #: which may speed up rendering
    max_minibatch_size: Optional[int] = None
class NvdiffrastRenderer(BirdviewRenderer):
    """
    Similar to PyTorch3DRenderer, and producing indistinguishable images, but sometimes faster.
    Note that nvdiffrast requires separate installation and is subject to its own license terms.

    Batched meshes are rasterized in nvdiffrast's range mode: the whole batch is packed into one
    vertex/face buffer, and a CPU ``ranges`` tensor selects the triangles of each batch element.
    """
    def __init__(self, cfg: NvdiffrastRendererConfig, *args, **kwargs):
        if not is_available:
            raise NvdiffrastNotFound()
        super().__init__(cfg, *args, **kwargs)
        self.cfg: NvdiffrastRendererConfig = cfg
        # nvdiffrast only rasterizes on CUDA devices, hence the 'cuda' default when no device is configured
        self.glctx = get_glctx_session('cuda' if cfg.device is None else cfg.device, opengl=self.cfg.opengl)
        if self.glctx is None:
            raise RuntimeError('Failed to obtain glctx session for nvdiffrast')

    def _warm_up_glctx(self, res: Resolution, device: torch.device) -> None:
        """
        Render a dummy batch of ``cfg.max_minibatch_size`` single-triangle images once per context,
        so nvdiffrast allocates its buffers before real rendering starts.

        Does nothing when ``max_minibatch_size`` is None or this context was already warmed up
        (tracked through the ``initial_dummy_frame_rendered`` attribute set on the context).
        """
        if self.cfg.max_minibatch_size is None or hasattr(self.glctx, 'initial_dummy_frame_rendered'):
            return
        dummy_verts = torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1]], dtype=torch.float32, device=device)
        dummy_faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device=device)
        # ranges stays on the CPU; every batch element draws the same single triangle
        dummy_ranges = torch.tensor([[0, 1]], dtype=torch.int32).expand(self.cfg.max_minibatch_size, -1).contiguous()
        # Use the resolution actually being rendered, so the pre-allocation matches real frames
        dr.rasterize(self.glctx, dummy_verts, dummy_faces, resolution=[res.height, res.width], ranges=dummy_ranges)
        self.glctx.initial_dummy_frame_rendered = True

    def render_rgb_mesh(self, mesh: RGBMesh, res: Resolution, cameras: Cameras) -> torch.Tensor:
        """
        Rasterize a batched RGB mesh into birdview images.

        Args:
            mesh: batched mesh on a GPU device; its ``attrs`` are interpolated as 3 channels
                (presumably colors in [0, 1], since the result is scaled by 255 -- confirm)
            res: output resolution; rasterization happens at ``[res.height, res.width]``
            cameras: cameras used to project the mesh into clip space
        Returns:
            tensor of shape (batch, res.width, res.height, 3), scaled by 255; the two spatial axes are
            swapped relative to nvdiffrast's [minibatch, height, width, C] layout
        Raises:
            RuntimeError: if the mesh is on the CPU
        """
        if mesh.device == torch.device('cpu'):
            raise RuntimeError('Nvdiffrast does not support CPU rendering; please move the mesh to the GPU.')
        if self.cfg.shift_mesh_by_camera_before_rendering:
            # Translate the geometry by -camera position and re-center the cameras at the origin
            mesh = mesh.translate(-cameras.xy)
            cameras = Cameras(xy=torch.zeros_like(cameras.xy), sc=cameras.sc, scale=cameras.scale)
        self._warm_up_glctx(res, mesh.device)

        # nvdiffrast requires contiguous inputs; reshape may return a non-contiguous view
        verts_clip = cameras.project_world_to_clip_space(mesh.verts).reshape(-1, 4).contiguous()
        # Range mode: batch element i owns triangles [i * faces_count, (i + 1) * faces_count)
        tris_first_idx = torch.arange(mesh.batch_size, device='cpu', dtype=torch.int32) * mesh.faces_count
        tris_count = torch.ones_like(tris_first_idx) * mesh.faces_count
        ranges = torch.stack([tris_first_idx, tris_count], dim=-1)  # must stay on the CPU for nvdiffrast
        # Shift face indices so each batch element addresses its own slice of the packed vertex buffer
        faces_offset = torch.arange(mesh.batch_size, device=mesh.faces.device, dtype=torch.int32) * mesh.verts_count
        faces = (mesh.faces + faces_offset[:, None, None]).reshape(-1, 3).to(torch.int32)
        vertices_attributes = mesh.attrs.reshape(-1, 3).contiguous()

        rast, _ = dr.rasterize(self.glctx, verts_clip, faces, resolution=[res.height, res.width],
                               ranges=ranges)
        image, _ = dr.interpolate(vertices_attributes, rast, faces)

        # Change background color in case it's not black
        background = self.get_color('background')
        if sum(background) != 0:
            # rast[..., 3] (triangle id + 1) is zero for pixels not covered by any triangle
            background_rgb = torch.tensor([x / 255.0 for x in background], dtype=image.dtype, device=image.device)
            image = torch.where(rast[..., 3:4] > 0, image, background_rgb)
        if self.cfg.antialias:
            image = dr.antialias(image, rast, verts_clip, faces)
        image = image[..., :3] * 255
        image = image.transpose(-2, -3)  # point x upwards, flip to right-handed coordinate frame
        if self.cfg.left_handed_coordinates:
            image = image.flip(dims=(-2,))  # flip horizontally
        return image