# Source code for torchdrivesim.map

import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List
from functools import cached_property

import lanelet2
import numpy as np
import torch

import torchdrivesim
from torchdrivesim.lanelet2 import LaneletMap, find_lanelet_directions, load_lanelet_map, road_mesh_from_lanelet_map, lanelet_map_to_lane_mesh
from torchdrivesim.mesh import BirdviewMesh
from torchdrivesim.traffic_controls import BaseTrafficControl, TrafficLightControl, StopSignControl, YieldControl
from torchdrivesim.traffic_lights import TrafficLightController
from torchdrivesim.utils import normalize_angle

@dataclass
class Stopline:
    """
    A single traffic-control record (stopline or sign) attached to a map.

    Instances are read from a map's stoplines JSON file (one dict per record) and
    written there when downloading a map from the Inverted AI API.

    On construction, hyphenated agent type spellings are rewritten to the
    snake_case names used elsewhere in this module: 'traffic-light' becomes
    'traffic_light', 'stop-sign' becomes 'stop_sign', and 'yield-sign' or 'yield'
    become 'yield_sign'. Any other value is kept as-is.
    """
    actor_id: int
    agent_type: str
    # Geometry of the control; units and angle convention follow the map's frame.
    x: float
    y: float
    length: float
    width: float
    orientation: float

    def __post_init__(self):
        # Each entry maps a group of accepted spellings to its canonical name.
        # Matching uses ==/`in` on tuples, exactly like the original comparisons.
        renames = (
            (("traffic-light",), "traffic_light"),
            (("stop-sign",), "stop_sign"),
            (("yield-sign", "yield"), "yield_sign"),
        )
        for spellings, canonical in renames:
            if self.agent_type in spellings:
                self.agent_type = canonical
                break
@dataclass
class MapConfig:
    """
    Encapsulates various map metadata, including where to find files defining the map.
    Map definition includes a coordinate frame and traffic signals.

    Every file path field is optional. The accessors below return None (or an empty
    list for stoplines) when the corresponding path is unset. The exception is
    ``road_mesh``, which is built from the lanelet map when no mesh file is given.
    """
    name: str
    left_handed_coordinates: bool = False
    center: Optional[Tuple[float, float]] = None
    lanelet_path: Optional[str] = None
    # Passed as ``origin`` to load_lanelet_map when parsing lanelet_path.
    lanelet_map_origin: Tuple[float, float] = (0, 0)
    mesh_path: Optional[str] = None
    stoplines_path: Optional[str] = None
    traffic_light_controller_path: Optional[str] = None
    iai_location_name: Optional[str] = None
    note: Optional[str] = None

    @property
    def lanelet_map(self) -> Optional[LaneletMap]:
        """Lanelet map parsed from ``lanelet_path``, or None if unset. Re-parsed on every access."""
        if self.lanelet_path is not None:
            return load_lanelet_map(self.lanelet_path, origin=self.lanelet_map_origin)
        return None

    @cached_property
    def road_mesh(self) -> Optional[BirdviewMesh]:
        """
        Bird's-eye road mesh for this map, computed once and cached on the instance.

        If ``mesh_path`` is set, the mesh is loaded from that file. Otherwise, if
        ``lanelet_path`` is set, a road mesh (category 'road') is built from the
        lanelet map and merged into a lane mesh. Returns None when neither path is set.
        """
        if self.mesh_path is not None:
            return BirdviewMesh.load(self.mesh_path)
        if self.lanelet_path is None:
            return None
        lanelet_map = self.lanelet_map
        drivable = road_mesh_from_lanelet_map(lanelet_map)
        # Tag the geometry as 'road' and keep it on the device of the freshly built mesh.
        drivable = BirdviewMesh.set_properties(drivable, category='road').to(drivable.device)
        # NOTE(review): left_handed is hard-coded to False rather than using
        # self.left_handed_coordinates — confirm this is intentional.
        lanes = lanelet_map_to_lane_mesh(lanelet_map, left_handed=False)
        return lanes.merge(drivable)

    @property
    def stoplines(self) -> List[Stopline]:
        """Stoplines read from the JSON list at ``stoplines_path``; empty list if unset."""
        if self.stoplines_path is None:
            return []
        with open(self.stoplines_path, 'r') as f:
            records = json.load(f)
        return [Stopline(**record) for record in records]

    @property
    def traffic_light_controller(self) -> Optional[TrafficLightController]:
        """Controller loaded from ``traffic_light_controller_path``, or None if unset."""
        if self.traffic_light_controller_path is not None:
            return TrafficLightController.from_json(self.traffic_light_controller_path)
        return None
[docs]def _filename_defaults(name: str) -> Dict[str, str]: return dict( lanelet_path=f'{name}.osm', mesh_path=f'{name}_mesh.json', stoplines_path=f'{name}_stoplines.json', traffic_light_controller_path=f'{name}_traffic_light_controller.json', )
def resolve_paths_to_absolute(cfg: MapConfig, root: str) -> MapConfig:
    """
    Return a copy of ``cfg`` with relative map file paths anchored at ``root``.

    For each file path field, an unset (None) value falls back to the conventional
    file name from ``_filename_defaults(cfg.name)``. Absolute paths are left alone.
    A relative path is replaced by ``os.path.join(root, path)`` only if that file
    exists; otherwise the field keeps its original value (possibly None).
    """
    updates = {}
    for field_name, fallback in _filename_defaults(cfg.name).items():
        relative = getattr(cfg, field_name)
        if relative is None:
            relative = fallback
        if not os.path.isabs(relative):
            anchored = os.path.join(root, relative)
            if os.path.exists(anchored):
                updates[field_name] = anchored
    return dataclasses.replace(cfg, **updates)
def load_map_config(json_path: str, resolve_paths: bool = True) -> MapConfig:
    """
    Read a MapConfig from a JSON file whose keys are MapConfig field names.

    If ``resolve_paths`` is True, relative file paths are resolved against the
    directory containing ``json_path`` (see resolve_paths_to_absolute).
    """
    with open(json_path, 'r') as f:
        fields = json.load(f)
    cfg = MapConfig(**fields)
    if not resolve_paths:
        return cfg
    return resolve_paths_to_absolute(cfg, os.path.dirname(json_path))
def store_map_config(cfg: MapConfig, json_path: str, store_absolute_paths: bool = False) -> None:
    """
    Write ``cfg`` to ``json_path`` as indented JSON.

    Unless ``store_absolute_paths`` is True, each file path field is reduced to its
    basename before writing, so the stored config refers to files by name only.
    None paths are written as null. ``cfg`` itself is not modified.
    """
    if not store_absolute_paths:
        basenames = {}
        for field_name in _filename_defaults(''):
            value = getattr(cfg, field_name)
            basenames[field_name] = None if value is None else os.path.basename(value)
        cfg = dataclasses.replace(cfg, **basenames)
    with open(json_path, 'w') as f:
        json.dump(dataclasses.asdict(cfg), f, indent=4)
def find_map_config(map_name: str, resolve_paths: bool = True) -> Optional[MapConfig]:
    """
    To retrieve configs for maps not bundled with the package,
    folders with corresponding names must be placed inside one of the directories
    listed in TDS_RESOURCE_PATH environment variable.
    Note that map names should be globally unique.

    The first resource root containing a directory named ``map_name`` wins. If that
    directory contains ``metadata.json``, the config is loaded from it. Otherwise a
    MapConfig with only ``name`` set is created.

    Args:
        map_name: name of the map directory to look for.
        resolve_paths: if True, relative file paths (and unset paths whose default
            file exists) are made absolute under the map directory. If False, paths
            are returned exactly as stored in metadata.json.
    Returns:
        The map config, or None if no resource root contains ``map_name``.
    """
    resource_path = torchdrivesim._resource_path
    for root in resource_path:
        map_path = os.path.join(root, map_name)
        if os.path.exists(map_path):
            break
    else:
        return None
    metadata_path = os.path.join(map_path, 'metadata.json')
    if os.path.exists(metadata_path):
        # Do not resolve here: resolution happens once, below, and only when the caller
        # asked for it. Previously load_map_config always resolved, ignoring
        # resolve_paths=False, and then resolved a second time.
        cfg = load_map_config(metadata_path, resolve_paths=False)
    else:
        cfg = MapConfig(name=map_name)
    if resolve_paths:
        cfg = resolve_paths_to_absolute(cfg, root=map_path)
    return cfg
def download_iai_map(location_name: str, save_path: str) -> None:
    """
    Downloads the map information using LOCATION_INFO from the Inverted AI API.
    IAI_API_KEY needs to be set for this to succeed.
    Basename of save_path will be used as the map name for TorchDriveSim.
    If the dirpath of save_path is in TDS_RESOURCE_PATH, the map will be immediately
    available in `find_map_config`.

    The following files are written inside ``save_path``, which is created if missing:
    the OSM lanelet map, a stoplines JSON built from every static actor of the
    location, a road mesh JSON derived from the lanelet map, and ``metadata.json``
    with the resulting MapConfig.
    """
    from invertedai import location_info

    info = location_info(location_name, include_map_source=True)
    os.makedirs(save_path, exist_ok=True)

    map_name = os.path.basename(save_path)
    filenames = _filename_defaults(map_name)
    center = info.map_center.x, info.map_center.y

    # Lanelet2 map in OSM format, plus the origin needed to load it back.
    lanelet_path = os.path.join(save_path, filenames['lanelet_path'])
    info.osm_map.save_osm_file(lanelet_path)
    origin = info.osm_map.origin.x, info.osm_map.origin.y

    # Static actors (traffic lights and signs) serialized as Stopline records.
    stoplines_path = os.path.join(save_path, filenames['stoplines_path'])
    stopline_records = []
    for actor in info.static_actors:
        record = Stopline(
            actor_id=actor.actor_id,
            agent_type=actor.agent_type,
            x=actor.center.x,
            y=actor.center.y,
            length=actor.length,
            width=actor.width,
            orientation=actor.orientation,
        )
        stopline_records.append(dataclasses.asdict(record))
    with open(stoplines_path, 'w') as f:
        json.dump(stopline_records, f, indent=4)

    is_carla = location_name.split(':')[0] == 'carla'
    cfg = MapConfig(
        name=map_name,
        center=center,
        lanelet_map_origin=origin,
        iai_location_name=location_name,
        left_handed_coordinates=is_carla,  # TODO: update once IAI API returns this info
        lanelet_path=os.path.abspath(lanelet_path),
        stoplines_path=os.path.abspath(stoplines_path),
    )

    # mesh_path is still unset here, so road_mesh is built from the lanelet map.
    mesh_path = os.path.join(save_path, filenames['mesh_path'])
    cfg.road_mesh.save(mesh_path)
    cfg.mesh_path = os.path.abspath(mesh_path)
    store_map_config(cfg, os.path.join(save_path, 'metadata.json'))
def traffic_controls_from_map_config(cfg: MapConfig) -> Dict[str, BaseTrafficControl]:
    """
    Build traffic control objects from the stoplines listed in a map config.

    Stoplines are grouped by ``agent_type``. Each of 'traffic_light', 'stop_sign',
    and 'yield_sign' that has at least one stopline yields one control. Each control
    receives a tensor of shape (1, N, 5) whose rows are
    [x, y, length, width, orientation]. Stoplines of any other type are ignored.
    Result keys appear in the order traffic_light, stop_sign, yield_sign.
    """
    control_classes = {
        'traffic_light': TrafficLightControl,
        'stop_sign': StopSignControl,
        'yield_sign': YieldControl,
    }
    grouped = {kind: [] for kind in control_classes}
    for stopline in cfg.stoplines:
        bucket = grouped.get(stopline.agent_type)
        if bucket is None:
            continue
        bucket.append(torch.tensor(
            [stopline.x, stopline.y, stopline.length, stopline.width, stopline.orientation],
        ))
    return {
        kind: control_classes[kind](torch.stack(poses).unsqueeze(0))
        for kind, poses in grouped.items()
        if poses
    }
def find_wrong_way_stoplines(map_cfg: MapConfig, angle_threshold: float = np.pi/6,
                             lanelet_dist_tolerance: float = 0) -> List[int]:
    """
    Find stoplines whose orientation disagrees with every lanelet direction at their position.

    A stopline is flagged when ``find_lanelet_directions`` returns at least one
    direction at (x, y) and none of them is within ``angle_threshold`` of the
    stopline's orientation (after angle normalization). A stopline with no lanelet
    directions found is not flagged.

    Args:
        map_cfg: map config providing the lanelet map and stoplines.
        angle_threshold: maximum absolute angular difference counted as aligned.
        lanelet_dist_tolerance: forwarded to ``find_lanelet_directions``. The default
            of 0 matches the previous hard-coded behavior.
    Returns:
        ``actor_id`` of each flagged stopline, in config order. Empty if the config
        has no lanelet map.
    """
    lanelet_map = map_cfg.lanelet_map
    if lanelet_map is None:
        return []
    wrong_way_stopline_ids = []
    for stopline in map_cfg.stoplines:
        lanelet_orientations = find_lanelet_directions(
            lanelet_map, stopline.x, stopline.y, lanelet_dist_tolerance=lanelet_dist_tolerance,
        )
        if not lanelet_orientations:
            # Nothing to compare against, so the stopline cannot be judged wrong-way.
            continue
        # Generator (not a list) lets any() stop at the first aligned lanelet.
        aligned = any(
            abs(normalize_angle(psi - stopline.orientation)) < angle_threshold
            for psi in lanelet_orientations
        )
        if not aligned:
            wrong_way_stopline_ids.append(stopline.actor_id)
    return wrong_way_stopline_ids