from dataclasses import dataclass
import numpy as np
import torch
from torchdrivesim.mesh import BirdviewMesh, RGBMesh, tensor_color
from torchdrivesim.rendering import BirdviewRenderer, RendererConfig
from torchdrivesim.rendering.base import Cameras
from torchdrivesim.utils import Resolution
@dataclass
class CV2RendererConfig(RendererConfig):
    """
    Configuration for :class:`CV2Renderer`.

    Extends :class:`RendererConfig` with options specific to the OpenCV backend.
    """
    # When True, the renderer first removes faces lying outside a slightly
    # enlarged (5% margin) camera viewing polygon via ``mesh.trim``, so that
    # OpenCV does not spend time rasterizing polygons that cannot be visible.
    trim_mesh_before_rendering: bool = True
class CV2Renderer(BirdviewRenderer):
    """
    Renderer based on OpenCV. Slow, but easy to install. Renders on CPU.

    Faces are rasterized one at a time with ``cv2.fillConvexPoly`` in an order
    derived from vertex z-coordinates; the finished images are then moved back
    to the mesh's device.
    """

    def __init__(self, cfg: CV2RendererConfig, *args, **kwargs):
        # Imported here only so that a missing OpenCV installation fails at
        # construction time; the module is actually used in ``render_rgb_mesh``.
        import cv2  # noqa: F401

        super().__init__(cfg, *args, **kwargs)
        self.cfg: CV2RendererConfig = cfg

    def render_rgb_mesh(self, mesh: RGBMesh, res: Resolution, cameras: Cameras) -> torch.Tensor:
        """
        Rasterize an RGB mesh with OpenCV.

        Args:
            mesh: mesh to draw; each face is filled with the color of its first vertex.
            res: output image resolution.
            cameras: cameras used to project mesh vertices to pixel coordinates.

        Returns:
            Float32 tensor of shape ``(mesh.batch_size, res.width, res.height, 3)``
            on ``mesh.device``, with channel values in [0, 255] (background is 0).
        """
        import cv2

        if self.cfg.shift_mesh_by_camera_before_rendering:
            # Move the geometry instead of the camera, leaving the camera at the origin.
            mesh = mesh.translate(-cameras.xy)
            cameras = Cameras(xy=torch.zeros_like(cameras.xy), sc=cameras.sc, scale=cameras.scale)

        if self.cfg.trim_mesh_before_rendering:
            # For efficiency, remove faces that are not visible anyway.
            # The four screen corners are mapped back to world coordinates, one
            # polygon per camera. They are built in the vertex dtype so they mix
            # cleanly with the (float) camera transforms.
            screen_corners = torch.tensor(
                [[0, 0], [0, res.height], [res.width, res.height], [res.width, 0]],
                device=mesh.device, dtype=mesh.verts.dtype,
            )
            viewing_polygon = cameras.reverse_transform_points_screen(
                screen_corners[None].repeat_interleave(cameras.xy.shape[0], dim=0), res=res
            )
            polygon_center = viewing_polygon.mean(dim=1, keepdim=True)
            viewing_polygon = polygon_center + (viewing_polygon - polygon_center) * 1.05  # safety margin
            mesh = mesh.trim(viewing_polygon)

        # We assume all vertices of the same face are on the same plane, so the z of
        # each face's first vertex serves as that face's depth. Gathering straight
        # from the (batch, verts) z-values avoids materializing a copy of every
        # vertex z for every face.
        face_depth = torch.gather(mesh.verts[..., 2], dim=1, index=mesh.faces[..., 0])
        # Faces with higher z are drawn first, so lower-z faces are painted on top.
        rendering_order = face_depth.argsort(dim=1, descending=True).unsqueeze(-1).cpu()
        mesh_faces = mesh.faces.cpu().gather(dim=1, index=rendering_order.expand_as(mesh.faces))

        pixel_verts = cameras.transform_points_screen(mesh.verts[..., :2], res=res).cpu().to(torch.int32)
        # Map attributes to 0..255 integers; the (1 - 1e-3) factor keeps 1.0 at 255.
        color_attrs = (mesh.attrs * (1.0 - 1e-3) * 256).floor().to(torch.uint8).cpu()

        image_batch = []
        for batch_idx in range(mesh.batch_size):
            # Convert to numpy once per batch element: indexing numpy arrays inside
            # the per-face loop is much cheaper than per-face torch indexing.
            verts_np = pixel_verts[batch_idx].numpy()
            faces_np = mesh_faces[batch_idx].numpy()
            colors_np = color_attrs[batch_idx].numpy()
            image = np.zeros((res.height, res.width, 3), dtype=np.float32)
            for face in faces_np:
                image = cv2.fillConvexPoly(
                    img=image, points=verts_np[face], color=colors_np[face[0]].tolist(),
                    shift=0, lineType=cv2.LINE_AA,
                )
            # point x upwards, flip to right-handed coordinate frame
            image_batch.append(torch.from_numpy(image).transpose(-2, -3))
        images = torch.stack(image_batch, dim=0).to(mesh.device)
        if self.cfg.left_handed_coordinates:
            # Flip horizontally exactly once; flipping per image and again per batch
            # would cancel out and silently ignore this option.
            images = images.flip(dims=(-2,))
        return images