# Source code for torchdrivesim.rendering.pytorch3d

"""
PyTorch3D-based renderers, used by default in TorchDriveSim.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Tuple

try:
    import pytorch3d
    import pytorch3d.renderer
    # True when pytorch3d imported successfully; every pytorch3d-backed entry point
    # in this module checks it and raises Pytorch3DNotFound otherwise.
    is_available = True
except ImportError:
    pytorch3d = None
    is_available = False

import torch
from torch.nn import functional as F

from torchdrivesim.mesh import RGBMesh
from torchdrivesim.rendering.base import RendererConfig, BirdviewRenderer, Cameras
from torchdrivesim.utils import Resolution
class Pytorch3DNotFound(ImportError):
    """
    Raised when functionality in this module needs pytorch3d but pytorch3d could not be imported.
    """
class RenderingBlend(Enum):
    """
    Selects which pytorch3d blending function turns rasterized fragments into pixel colors.
    Renderers that are not pytorch3d-based may ignore this setting.

    See https://pytorch3d.readthedocs.io/en/latest/modules/renderer/blending.html
    """
    hard = 'hard'        # pytorch3d.renderer.hard_rgb_blend
    soft = 'soft'        # pytorch3d.renderer.softmax_rgb_blend
    sigmoid = 'sigmoid'  # pytorch3d.renderer.sigmoid_alpha_blend
@dataclass
class Pytorch3DRendererConfig(RendererConfig):
    """
    Settings for the pytorch3d-backed birdview renderer.
    """
    # Backend name carried by this config; presumably used to select this renderer
    # when building from a config — TODO confirm against the renderer factory.
    backend: str = 'pytorch3d'
    # Blending mode handed to the shader (as `blend`) when the renderer is built.
    differentiable_rendering: RenderingBlend = RenderingBlend.soft
class Shader2D(torch.nn.Module):
    """
    Shader that ignores lighting, based on https://github.com/facebookresearch/pytorch3d/issues/84

    Pixel colors are sampled directly from the mesh textures and combined with one of
    pytorch3d's blending functions, chosen by ``blend``.
    """
    def __init__(self, device="cpu", background_color: Tuple[float, float, float] = (0, 0, 0),
                 blend=RenderingBlend.soft):
        """
        Args:
            device: accepted for signature compatibility; not used by this shader.
            background_color: RGB color forwarded to pytorch3d ``BlendParams``.
            blend: a ``RenderingBlend`` member, or its string value (e.g. ``'soft'``).
        """
        super().__init__()
        self.background_color = background_color
        self.blend = blend

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        """
        Blends texture colors sampled at the rasterized fragments into an image.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: pytorch3d meshes whose textures are sampled.
            kwargs: ignored.
        Returns:
            The image produced by the pytorch3d blend function selected by ``self.blend``.
        Raises:
            Pytorch3DNotFound: if pytorch3d could not be imported.
            ValueError: if ``self.blend`` is not a ``RenderingBlend`` member or value.
        """
        if not is_available:
            raise Pytorch3DNotFound()
        from pytorch3d.renderer import BlendParams

        # Accept plain strings such as 'soft' (e.g. read from a config file) as well as
        # enum members; anything else keeps the original error message.
        try:
            blend = RenderingBlend(self.blend)
        except ValueError:
            raise ValueError("Unrecognized blend type: " + str(self.blend)) from None
        blend_fn = {
            RenderingBlend.soft: pytorch3d.renderer.softmax_rgb_blend,
            RenderingBlend.hard: pytorch3d.renderer.hard_rgb_blend,
            RenderingBlend.sigmoid: pytorch3d.renderer.sigmoid_alpha_blend,
        }[blend]

        pixel_colors = meshes.sample_textures(fragments)
        return blend_fn(pixel_colors, fragments, BlendParams(background_color=self.background_color))
class Pytorch3DRenderer(BirdviewRenderer):
    """
    Renderer based on pytorch3d, using an orthographic projection and a trivial shader.
    Works on both GPU and CPU, but CPU is very slow.
    """
    def __init__(self, cfg: Pytorch3DRendererConfig, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        self.cfg: Pytorch3DRendererConfig = cfg
        # Renderer for the default resolution, reused across calls; other resolutions
        # get a freshly built renderer in render_rgb_mesh.
        self.renderer = self.make_renderer(
            self.res,
            blend=self.cfg.differentiable_rendering,
            background_color=self._background_color(),
        )

    def _background_color(self) -> Tuple[float, ...]:
        """Background color from ``get_color('background')`` with every channel divided by 255."""
        return tuple(x / 255.0 for x in self.get_color('background'))

    def render_rgb_mesh(self, mesh: RGBMesh, res: Resolution, cameras: Cameras) -> torch.Tensor:
        """
        Renders ``mesh`` as seen from ``cameras`` at resolution ``res``.

        Returns:
            RGB image (alpha channel dropped, values multiplied by 255) with the two image
            axes swapped relative to pytorch3d's output and, when
            ``cfg.left_handed_coordinates`` is set, flipped along dimension -2.
        """
        if self.cfg.shift_mesh_by_camera_before_rendering:
            # Move the mesh instead of the camera, so the camera sits at the origin.
            mesh = mesh.translate(-cameras.xy)
            cameras = Cameras(xy=torch.zeros_like(cameras.xy), sc=cameras.sc, scale=cameras.scale)
        meshes = mesh.pytorch3d()
        if res != self.res:
            renderer = self.make_renderer(res, self.cfg.differentiable_rendering, self._background_color())
        else:
            renderer = self.renderer
        # Pass cameras per call (MeshRenderer forwards kwargs to the rasterizer) rather than
        # assigning renderer.rasterizer.cameras: the shared renderer is not mutated and does
        # not keep the last camera tensors alive between calls.
        image = renderer(meshes, cameras=construct_pytorch3d_cameras(cameras))
        image = image[..., :3] * 255
        image = image.transpose(-2, -3)  # point x upwards, flip to right-handed coordinate frame
        if self.cfg.left_handed_coordinates:
            image = image.flip(dims=(-2,))  # flip horizontally
        return image

    @classmethod
    def make_renderer(cls, res, blend, background_color):
        """
        Builds a pytorch3d ``MeshRenderer`` with a single face per pixel, no blur and a ``Shader2D``.

        Raises:
            Pytorch3DNotFound: if pytorch3d could not be imported.
        """
        if not is_available:
            raise Pytorch3DNotFound()
        settings = pytorch3d.renderer.mesh.rasterizer.RasterizationSettings(
            # NOTE(review): only res.height is used, so the raster is square; confirm that
            # non-square resolutions are never requested here.
            image_size=res.height,
            blur_radius=0.0,
            faces_per_pixel=1,
        )
        return pytorch3d.renderer.MeshRenderer(
            rasterizer=pytorch3d.renderer.MeshRasterizer(raster_settings=settings),
            shader=Shader2D(background_color=background_color, blend=blend),  # type: ignore
        )
def construct_pytorch3d_cameras(cameras: Cameras) -> "pytorch3d.renderer.FoVOrthographicCameras":
    """
    Converts TorchDriveSim cameras into pytorch3d orthographic cameras.

    Args:
        cameras: camera batch; ``cameras.xy`` and ``cameras.sc`` must have the same shape
            (last dimension 2). ``cameras.scale`` may be a Python number or a tensor with
            one value per camera.
    Returns:
        ``pytorch3d.renderer.FoVOrthographicCameras`` on the device of ``cameras.xy``.
    Raises:
        Pytorch3DNotFound: if pytorch3d could not be imported.
        ValueError: if ``cameras.xy`` and ``cameras.sc`` have different shapes.
    """
    if not is_available:
        raise Pytorch3DNotFound()
    xy, sc, scale = cameras.xy, cameras.sc, cameras.scale
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if xy.shape != sc.shape:
        raise ValueError(
            f"cameras.xy and cameras.sc must have the same shape, "
            f"got {tuple(xy.shape)} and {tuple(sc.shape)}"
        )
    device = xy.device
    # Rows are (sc[1], -sc[0]) and (sc[0], sc[1]); presumably sc = (sin, cos) of the camera
    # heading, which makes this a standard 2D rotation matrix — TODO confirm.
    cs_neg = torch.flip(sc, dims=(-1,)) * torch.tensor([[1, -1]], dtype=sc.dtype, device=sc.device)
    rotation_matrix = torch.stack([cs_neg, sc], dim=-2)
    # pytorch3d seems to rotate the provided translation vector
    reverse_rotation = rotation_matrix.transpose(-1, -2)
    rotated_translation = reverse_rotation.matmul(-xy.unsqueeze(-1)).squeeze(-1)
    # Lift to 3D: zero translation along z, identity rotation along z.
    t = F.pad(rotated_translation, (0, 1), mode='constant', value=0)
    r = F.pad(rotation_matrix, (0, 1, 0, 1), mode='constant', value=0)
    r[..., -1, -1] = 1
    # as_tensor + stack accepts both a Python number and a per-camera tensor scale.
    scale = torch.as_tensor(scale, dtype=xy.dtype, device=device)
    scale_xyz = torch.stack([scale, scale, torch.ones_like(scale)], dim=-1).expand_as(t)
    return pytorch3d.renderer.FoVOrthographicCameras(device=device, scale_xyz=scale_xyz, T=t, R=r)