import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List
from functools import cached_property
import lanelet2
import numpy as np
import torch
import torchdrivesim
from torchdrivesim.lanelet2 import LaneletMap, find_lanelet_directions, load_lanelet_map, road_mesh_from_lanelet_map, lanelet_map_to_lane_mesh
from torchdrivesim.mesh import BirdviewMesh
from torchdrivesim.traffic_controls import BaseTrafficControl, TrafficLightControl, StopSignControl, YieldControl
from torchdrivesim.traffic_lights import TrafficLightController
from torchdrivesim.utils import normalize_angle
@dataclass
class Stopline:
    """
    A stop line attached to a traffic control (traffic light, stop sign or yield sign).

    Instances are built from keyword arguments, either from entries of the
    stoplines JSON file or from the static actors of a downloaded map.
    """
    # NOTE(review): the field list was reconstructed from the keyword arguments
    # used by callers in this module; the declaration order and the exact
    # annotations should be confirmed against the upstream definition.
    actor_id: int  # identifier reported by `find_wrong_way_stoplines`
    agent_type: str  # kind of traffic control, normalized in `__post_init__`
    x: float
    y: float
    length: float
    width: float
    orientation: float

    def __post_init__(self):
        """Rewrite hyphenated or shortened ``agent_type`` spellings to the underscore form."""
        aliases = {
            "traffic-light": "traffic_light",
            "stop-sign": "stop_sign",
            "yield-sign": "yield_sign",
            "yield": "yield_sign",
        }
        # Unknown agent types are left untouched.
        self.agent_type = aliases.get(self.agent_type, self.agent_type)
@dataclass
class MapConfig:
    """
    Encapsulates various map metadata, including where to find files defining the map.
    Map definition includes a coordinate frame and traffic signals.
    """
    # Map name. Required: it is the only argument passed when a config is built
    # for a map without metadata, and it determines the default file names.
    # NOTE(review): restored from its usages in this module; confirm against upstream.
    name: str
    left_handed_coordinates: bool = False
    center: Optional[Tuple[float, float]] = None
    lanelet_path: Optional[str] = None
    # Passed as `origin` when loading the lanelet map.
    lanelet_map_origin: Tuple[float, float] = (0, 0)
    mesh_path: Optional[str] = None
    stoplines_path: Optional[str] = None
    traffic_light_controller_path: Optional[str] = None
    iai_location_name: Optional[str] = None
    note: Optional[str] = None

    @property
    def lanelet_map(self) -> Optional[LaneletMap]:
        """
        Load the lanelet map from `lanelet_path`, or return None when no path is set.
        This is a plain property, so the map is loaded again on every access.
        """
        if self.lanelet_path is None:
            return None
        return load_lanelet_map(self.lanelet_path, origin=self.lanelet_map_origin)

    @cached_property
    def road_mesh(self) -> Optional[BirdviewMesh]:
        """
        Return the road mesh, computed on first access and cached on the instance.

        When `mesh_path` is set the mesh is loaded from that file. Otherwise it is
        generated from the lanelet map by merging the lane mesh with the road mesh,
        or None is returned if `lanelet_path` is not set either.
        """
        if self.mesh_path is None:
            if self.lanelet_path is None:
                return None
            else:
                lanelet_map = self.lanelet_map
                road_mesh = road_mesh_from_lanelet_map(lanelet_map)
                road_mesh = BirdviewMesh.set_properties(road_mesh, category='road').to(road_mesh.device)
                lane_mesh = lanelet_map_to_lane_mesh(lanelet_map, left_handed=False)
                combined_mesh = lane_mesh.merge(road_mesh)
                return combined_mesh
        else:
            return BirdviewMesh.load(self.mesh_path)

    @property
    def stoplines(self) -> List[Stopline]:
        """
        Read the stoplines from the JSON file at `stoplines_path`.
        Returns an empty list when no path is set. The file is re-read on every access.
        """
        if self.stoplines_path is None:
            return []
        with open(self.stoplines_path, 'r') as f:
            # Each JSON entry supplies the keyword arguments for one Stopline.
            stoplines = [Stopline(**d) for d in json.load(f)]
        return stoplines

    @property
    def traffic_light_controller(self) -> Optional[TrafficLightController]:
        """
        Build the traffic light controller from the JSON file at
        `traffic_light_controller_path`, or return None when no path is set.
        """
        if self.traffic_light_controller_path is None:
            return None
        return TrafficLightController.from_json(self.traffic_light_controller_path)
[docs]def _filename_defaults(name: str) -> Dict[str, str]:
return dict(
lanelet_path=f'{name}.osm',
mesh_path=f'{name}_mesh.json',
stoplines_path=f'{name}_stoplines.json',
traffic_light_controller_path=f'{name}_traffic_light_controller.json',
)
def resolve_paths_to_absolute(cfg: MapConfig, root: str) -> MapConfig:
    """
    Return a copy of `cfg` whose path attributes are joined onto `root`
    wherever the joined file exists.

    An unset path falls back to the default file name for the map. Paths that
    are already absolute, and paths whose joined location does not exist,
    are left as they are in `cfg`.
    """
    updates = {}
    for field_name, fallback in _filename_defaults(cfg.name).items():
        current = getattr(cfg, field_name)
        relative = fallback if current is None else current
        if not os.path.isabs(relative):
            joined = os.path.join(root, relative)
            if os.path.exists(joined):
                updates[field_name] = joined
    return dataclasses.replace(cfg, **updates)
def load_map_config(json_path: str, resolve_paths: bool = True) -> MapConfig:
    """
    Build a MapConfig from the JSON file at `json_path`.

    When `resolve_paths` is true, path attributes are resolved relative to
    the directory containing `json_path`.
    """
    with open(json_path, 'r') as f:
        raw = json.load(f)
    cfg = MapConfig(**raw)
    if not resolve_paths:
        return cfg
    return resolve_paths_to_absolute(cfg, os.path.dirname(json_path))
def store_map_config(cfg: MapConfig, json_path: str, store_absolute_paths: bool = False) -> None:
    """
    Write `cfg` as JSON to `json_path`.

    Unless `store_absolute_paths` is true, every path attribute that is set
    is reduced to its base name before writing; `cfg` itself is not modified.
    """
    if not store_absolute_paths:
        stripped = {}
        for field_name in _filename_defaults(''):
            path = getattr(cfg, field_name)
            stripped[field_name] = None if path is None else os.path.basename(path)
        cfg = dataclasses.replace(cfg, **stripped)
    with open(json_path, 'w') as f:
        json.dump(dataclasses.asdict(cfg), f, indent=4)
def find_map_config(map_name: str, resolve_paths: bool = True) -> Optional[MapConfig]:
    """
    To retrieve configs for maps not bundled with the package,
    folders with corresponding names must be placed inside one of the directories
    listed in TDS_RESOURCE_PATH environment variable.
    Note that map names should be globally unique.

    Args:
        map_name: name of the folder to look for in each resource directory.
        resolve_paths: whether to resolve the path attributes of the config
            relative to the map folder.

    Returns:
        The config of the first matching map folder, or None if no resource
        directory contains a folder named `map_name`.
    """
    for root in torchdrivesim._resource_path:
        map_path = os.path.join(root, map_name)
        if os.path.exists(map_path):
            break
    else:
        return None
    metadata_path = os.path.join(map_path, 'metadata.json')
    if os.path.exists(metadata_path):
        # Load without resolving: resolution happens once below, and only when
        # the caller asked for it. Relying on the default here would resolve
        # the paths even for resolve_paths=False.
        cfg = load_map_config(metadata_path, resolve_paths=False)
    else:
        cfg = MapConfig(name=map_name)
    if resolve_paths:
        cfg = resolve_paths_to_absolute(cfg, root=map_path)
    return cfg
def download_iai_map(location_name: str, save_path: str) -> None:
    """
    Downloads the map information using LOCATION_INFO from the Inverted AI API.
    IAI_API_KEY needs to be set for this to succeed.
    Basename of save_path will be used as the map name for TorchDriveSim.
    If the dirpath of save_path is in TDS_RESOURCE_PATH, the map will be immediately
    available in `find_map_config`.

    Files written into `save_path`: the lanelet map (.osm), the stoplines JSON,
    the generated road mesh and `metadata.json`.
    """
    # Imported here so the module does not require `invertedai` unless a download is requested.
    from invertedai import location_info

    info = location_info(location_name, include_map_source=True)
    os.makedirs(save_path, exist_ok=True)
    # Normalize first: basename of a path with a trailing separator is empty,
    # which would produce an empty map name and files such as ".osm".
    map_name = os.path.basename(os.path.normpath(save_path))
    pathname_defaults = _filename_defaults(map_name)
    center = info.map_center.x, info.map_center.y

    lanelet_path = os.path.join(save_path, pathname_defaults['lanelet_path'])
    info.osm_map.save_osm_file(lanelet_path)
    origin = info.osm_map.origin.x, info.osm_map.origin.y

    stoplines_path = os.path.join(save_path, pathname_defaults['stoplines_path'])
    stoplines = [
        dataclasses.asdict(Stopline(
            actor_id=sa.actor_id, agent_type=sa.agent_type, x=sa.center.x, y=sa.center.y,
            length=sa.length, width=sa.width, orientation=sa.orientation,
        )) for sa in info.static_actors
    ]
    with open(stoplines_path, 'w') as f:
        json.dump(stoplines, f, indent=4)

    cfg = MapConfig(
        name=map_name, center=center, lanelet_map_origin=origin,
        iai_location_name=location_name,
        left_handed_coordinates=location_name.split(':')[0] == 'carla',  # TODO: update once IAI API returns this info
        lanelet_path=os.path.abspath(lanelet_path),
        stoplines_path=os.path.abspath(stoplines_path),
    )
    # mesh_path is still unset here, so road_mesh is generated from the lanelet map.
    mesh_path = os.path.join(save_path, pathname_defaults['mesh_path'])
    cfg.road_mesh.save(mesh_path)
    cfg.mesh_path = os.path.abspath(mesh_path)
    store_map_config(cfg, os.path.join(save_path, 'metadata.json'))
def traffic_controls_from_map_config(cfg: MapConfig) -> Dict[str, BaseTrafficControl]:
    """
    Build traffic control objects from the stoplines of `cfg`.

    Stoplines are grouped by agent type; other agent types are ignored.
    For every type with at least one stopline, the control is constructed from
    the stacked [x, y, length, width, orientation] rows with a leading
    dimension of size one. Types without stoplines are absent from the result.
    """
    grouped = {'traffic_light': [], 'stop_sign': [], 'yield_sign': []}
    for line in cfg.stoplines:
        if line.agent_type in grouped:
            row = torch.tensor(
                [line.x, line.y, line.length, line.width, line.orientation],
            )
            grouped[line.agent_type].append(row)

    controls = dict()
    light_rows = grouped['traffic_light']
    if light_rows:
        controls['traffic_light'] = TrafficLightControl(torch.stack(light_rows).unsqueeze(0))
    stop_rows = grouped['stop_sign']
    if stop_rows:
        controls['stop_sign'] = StopSignControl(torch.stack(stop_rows).unsqueeze(0))
    yield_rows = grouped['yield_sign']
    if yield_rows:
        controls['yield_sign'] = YieldControl(torch.stack(yield_rows).unsqueeze(0))
    return controls
def find_wrong_way_stoplines(map_cfg: MapConfig, angle_threshold: float = np.pi/6) -> List[int]:
    """
    Return the actor ids of stoplines that do not align with any lanelet direction.

    A stopline is reported when lanelet directions are found at its position
    and none of them is within `angle_threshold` of the stopline orientation.
    Stoplines with no lanelet direction found are skipped. Returns an empty
    list when the config has no lanelet map.
    """
    lanelet_map = map_cfg.lanelet_map
    if lanelet_map is None:
        return []
    offending_ids = []
    for line in map_cfg.stoplines:
        directions = find_lanelet_directions(lanelet_map, line.x, line.y, lanelet_dist_tolerance=0)
        if not directions:
            continue
        deviations = [abs(normalize_angle(psi - line.orientation)) for psi in directions]
        if not any(deviation < angle_threshold for deviation in deviations):
            offending_ids.append(line.actor_id)
    return offending_ids