torchdrivesim.map

Classes

BirdviewMesh

2D mesh with vertices and faces assigned to discrete categories, to facilitate rendering simple 2D worlds.

BaseTrafficControl

Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state.

TrafficLightControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green'].

StopSignControl

Stop sign, indicating the vehicle should stop and yield to cross traffic.

YieldControl

Yield sign, indicating that cross traffic has priority.

TrafficLightController

Stopline

MapConfig

Encapsulates various map metadata, including where to find files defining the map.

Functions

find_lanelet_directions(→ List[float])

For a given point, find local orientations of all lanelets that contain it.

load_lanelet_map(→ LaneletMap)

Load a Lanelet2 map from an OSM file on disk.

road_mesh_from_lanelet_map(→ torchdrivesim.mesh.BaseMesh)

Creates a road mesh by triangulating all lanelets in a given map.

lanelet_map_to_lane_mesh(→ torchdrivesim.mesh.BirdviewMesh)

Creates a lane marking mesh from a given map.

normalize_angle(angle)

Normalize to the [-pi, pi) range by shifting by a multiple of 2*pi.

_filename_defaults(→ Dict[str, str])

resolve_paths_to_absolute(→ MapConfig)

load_map_config(→ MapConfig)

store_map_config(→ None)

find_map_config(→ Optional[MapConfig])

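To retrieve configs for maps not bundled with the package, folders with corresponding names must be placed inside one of the directories listed in the TDS_RESOURCE_PATH environment variable.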

download_iai_map(→ None)

Downloads the map information using LOCATION_INFO from the Inverted AI API.

traffic_controls_from_map_config(→ Dict[str, ...)

find_wrong_way_stoplines(→ List[int])

Module Contents

torchdrivesim.map.find_lanelet_directions(lanelet_map: LaneletMap, x: float, y: float, tags_to_exclude: List[str] | None = None, lanelet_dist_tolerance: float = 1.0) → List[float]

For a given point, find local orientations of all lanelets that contain it.

Parameters:
  • lanelet_map – the map with a lanelet layer

  • x – first coordinate of the point in the map’s coordinate frame

  • y – second coordinate of the point in the map’s coordinate frame

  • tags_to_exclude – lanelets tagged with any of those tags will be omitted as if they didn’t exist

  • lanelet_dist_tolerance – how far outside a lanelet a car can be for the lanelet to still count as containing it

Returns:

for each lanelet, the angle representing its local orientation in radians

Raises:

Lanelet2NotFound – if the lanelet2 package is not available

torchdrivesim.map.load_lanelet_map(map_path: str, origin: Tuple[float, float] = (0, 0)) → LaneletMap

Load a Lanelet2 map from an OSM file on disk.

Parameters:
  • map_path – local path to OSM file containing the map

  • origin – latitude and longitude of the origin to use with the UTM projector

Raises:
  • Lanelet2NotFound – if the lanelet2 package is not available

  • FileNotFoundError – if the specified file doesn’t exist

torchdrivesim.map.road_mesh_from_lanelet_map(lanelet_map: LaneletMap, lanelets: List[int] | None = None) → torchdrivesim.mesh.BaseMesh

Creates a road mesh by triangulating all lanelets in a given map.

Parameters:
  • lanelet_map – map to triangulate

  • lanelets – if specified, only use lanelets with those ids
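One common way to triangulate a lanelet is to treat it as a strip between its left and right boundary polylines and split each quad into two triangles. The sketch below assumes both boundaries have the same number of points; the function name is hypothetical, and the library's actual triangulation (including how it handles boundaries with different point counts) may differ.

```python
from typing import List, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def triangulate_strip(left: List[Point], right: List[Point]) -> Tuple[List[Point], List[Face]]:
    """Hypothetical helper: triangulate the region between two boundary polylines.

    Vertices are stored as left[0..n-1] followed by right[0..n-1], and each quad
    (left[i], left[i+1], right[i+1], right[i]) is split into two triangles
    that share the same winding order.
    """
    if len(left) != len(right) or len(left) < 2:
        raise ValueError("boundaries must have equal length and at least 2 points")
    n = len(left)
    verts = list(left) + list(right)
    faces: List[Face] = []
    for i in range(n - 1):
        l0, l1 = i, i + 1
        r0, r1 = n + i, n + i + 1
        faces.append((l0, r0, l1))
        faces.append((l1, r0, r1))
    return verts, faces


verts, faces = triangulate_strip([(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)],
                                 [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)])
print(len(verts), len(faces))  # 6 4
```

Triangulating every lanelet this way and concatenating the results gives a single road mesh; the lanelets parameter above simply restricts which lanelets take part.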

torchdrivesim.map.lanelet_map_to_lane_mesh(lanelet_map: LaneletMap, left_handed: bool = False, batch_size: int = 50000, left_right_marking_join_threshold: float = 0.1) → torchdrivesim.mesh.BirdviewMesh

Creates a lane marking mesh from a given map.

Parameters:
  • lanelet_map – map to use

  • left_handed – whether the map’s coordinate system is left-handed (flips the left and right boundary designations)

  • batch_size – controls the number of points processed in parallel

  • left_right_marking_join_threshold – if left and right markings are this close, they will be treated as joined

class torchdrivesim.map.BirdviewMesh

Bases: BaseMesh

2D mesh with vertices and faces assigned to discrete categories, to facilitate rendering simple 2D worlds. Category assignment is stored per vertex, and faces should not mix vertices from different categories. For each category there is a color and a rendering priority z (lower renders on top), but those don’t need to be specified until the mesh is converted to a different representation.

categories: List[str]
colors: Dict[str, torch.Tensor]
zs: Dict[str, float]
vert_category: torch.Tensor
_cat_fill: int = 0
_validate_input()
property num_categories: int
classmethod set_properties(mesh: BaseMesh, category: str, color: Color | None = None, z: float | None = None)

Lifts a BaseMesh into a BirdviewMesh with a single category.

to(device)

Moves all tensors to a given device, returning a new mesh object. Does not modify the existing mesh object.

expand(size)

Expands batch dimension on the right by the given factor, returning a new mesh.

select_batch_elements(idx)

Selects the given indices from the batch dimension, possibly with repetitions, returning a new mesh.

classmethod unify(meshes)

Generates meshes equivalent to the input meshes, except that they all share the same category definitions.
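Unifying category definitions amounts to building one shared category list and remapping each mesh's per-vertex category indices into it. A minimal pure-Python sketch, assuming categories are identified by name and that vert_category holds integer indices into each mesh's own categories list; the function is hypothetical and ignores colors and z values.

```python
from typing import List, Tuple


def unify_categories(
    meshes: List[Tuple[List[str], List[int]]]
) -> Tuple[List[str], List[List[int]]]:
    """Hypothetical helper: map (categories, vert_category) pairs onto one shared category list.

    The shared list keeps categories in first-seen order across all meshes.
    """
    shared: List[str] = []
    for categories, _ in meshes:
        for name in categories:
            if name not in shared:
                shared.append(name)
    index = {name: i for i, name in enumerate(shared)}
    # Translate each local index to its category name, then to the shared index.
    remapped = [[index[categories[c]] for c in vert_category] for categories, vert_category in meshes]
    return shared, remapped


shared, remapped = unify_categories([
    (["road", "lane"], [0, 0, 1]),
    (["lane", "sidewalk"], [1, 0]),
])
print(shared)    # ['road', 'lane', 'sidewalk']
print(remapped)  # [[0, 0, 1], [2, 1]]
```

Once every mesh uses the same category list, operations such as concat and collate can combine per-vertex category tensors directly.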

classmethod concat(meshes)

Concatenates multiple meshes to form a single scene.

classmethod collate(meshes)

Batches a collection of meshes with appropriate padding. All input meshes must have a singleton batch dimension.

fill_attr() → RGBMesh

Computes explicit color for each vertex and augments vertices with z values corresponding to rendering priority.
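Conceptually, filling attributes looks up each vertex's category to obtain its color and z, then appends z as a third coordinate. The pure-Python sketch below mirrors the categories, colors, zs, and vert_category attributes listed above using plain lists and dicts instead of tensors; the helper name and the RGB-tuple color representation are assumptions.

```python
from typing import Dict, List, Tuple

RGB = Tuple[int, int, int]


def fill_vertex_attributes(
    verts: List[Tuple[float, float]],
    vert_category: List[int],
    categories: List[str],
    colors: Dict[str, RGB],
    zs: Dict[str, float],
) -> Tuple[List[Tuple[float, float, float]], List[RGB]]:
    """Hypothetical helper: return (x, y, z) vertices and one color per vertex."""
    verts_3d = []
    vert_colors = []
    for (x, y), cat_idx in zip(verts, vert_category):
        name = categories[cat_idx]
        verts_3d.append((x, y, zs[name]))  # z is the category's rendering priority
        vert_colors.append(colors[name])
    return verts_3d, vert_colors


verts_3d, vert_colors = fill_vertex_attributes(
    verts=[(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)],
    vert_category=[0, 1, 1],
    categories=["road", "lane"],
    colors={"road": (50, 50, 50), "lane": (255, 255, 255)},
    zs={"road": 3.0, "lane": 2.0},
)
print(verts_3d)  # [(0.0, 0.0, 3.0), (1.0, 0.0, 2.0), (0.0, 1.0, 2.0)]
```

Because lower z renders on top, the "lane" vertices in this example would be drawn over the "road" vertices.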

pytorch3d(include_textures=True) → pytorch3d.structures.Meshes

Converts the mesh to a PyTorch3D one. For the base class there are no textures, but subclasses may include them. Empty meshes are augmented with a single degenerate face on conversion, since PyTorch3D does not handle empty meshes correctly.

classmethod unpickle(mesh_file_path: str, pickle_module: Any = default_pickle_module)

Load a mesh of this type from the given file.

serialize()
classmethod _deserialize_tensors(data: Dict) → Dict

Converts list attributes to tensor attributes, if there are any tensor attributes.

classmethod empty(dim=2, batch_size=1)

Create empty mesh.

trim(polygon: torch.Tensor, trim_face_only=False)

Crops the mesh to a given 2D convex polygon, returning a new mesh. Faces where all vertices are outside the polygon are removed, even if the triangle intersects with the polygon. Vertices are removed if they are not used by any face.

Parameters:
  • polygon – BxPx2 tensor specifying a convex polygon, with vertices in either clockwise or counterclockwise order

  • trim_face_only – whether to keep all vertices, including those not used by any face
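The face-removal rule above can be illustrated with a convex point-in-polygon test based on cross-product signs, which works for either vertex ordering. This is a sketch for a single, unbatched mesh using plain lists; the function names are hypothetical, it always drops unused vertices (the default trim_face_only=False behavior), and treating points exactly on an edge as inside is an assumption.

```python
from typing import List, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def inside_convex(p: Point, polygon: List[Point]) -> bool:
    """True if p lies inside (or on the boundary of) a convex polygon given in CW or CCW order."""
    signs = set()
    n = len(polygon)
    for i in range(n):
        (ax, ay), (bx, by) = polygon[i], polygon[(i + 1) % n]
        cross = (bx - ax) * (p[1] - ay) - (by - ay) * (p[0] - ax)
        if cross != 0:
            signs.add(cross > 0)
    # Inside means p is on the same side of every edge, whichever way the polygon winds.
    return len(signs) <= 1


def trim_faces(verts: List[Point], faces: List[Face], polygon: List[Point]) -> Tuple[List[Point], List[Face]]:
    """Hypothetical helper: drop faces whose vertices are all outside, then drop unused vertices."""
    inside = [inside_convex(v, polygon) for v in verts]
    kept = [f for f in faces if any(inside[i] for i in f)]
    used = sorted({i for f in kept for i in f})
    new_index = {old: new for new, old in enumerate(used)}
    return [verts[i] for i in used], [tuple(new_index[i] for i in f) for f in kept]


square = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
verts = [(1.0, 1.0), (3.0, 1.0), (3.0, 3.0), (10.0, 10.0), (11.0, 10.0), (10.0, 11.0)]
faces = [(0, 1, 2), (3, 4, 5)]
print(trim_faces(verts, faces, square))  # keeps only the first triangle
```

Note that the first triangle is kept whole even though two of its vertices lie outside the square, matching the documented behavior that only faces with all vertices outside are removed.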

separate_by_category() → Dict[str, BaseMesh]

Splits the mesh into meshes representing different categories.

class torchdrivesim.map.BaseTrafficControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Base traffic control class. All traffic controls are static rectangular stoplines with additional discrete state. States are advanced by calling step: they are taken from the given recorded history of states while it lasts, and computed by compute_state once the current timestep exceeds the recorded history.

Parameters:
  • pos – BxNx5 stoplines tensor [x, y, length/thickness, width, orientation]

  • allowed_states – Allowed values of state, e.g. colors for traffic lights.

  • replay_states – BxNxT tensor of states to replay, represented as indices into allowed_states. Default T=0.

  • mask – BxN boolean tensor indicating whether a given traffic control element is present and not a padding.

classmethod _default_allowed_states() → List[str]
_default_replay_states() → torch.Tensor
_default_mask() → torch.Tensor
_default_state() → torch.Tensor
property total_replay_time: int
copy()

Duplicates this object, allowing for independent subsequent execution.

to(device: torch.device)

Modifies the object in-place, putting all tensors on the device provided.

extend(n: int, in_place: bool = True)

Multiplies the first batch dimension by the given number. This is equivalent to introducing an extra batch dimension on the right and then flattening.
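On a plain list standing in for the batch dimension, "add a dimension on the right, then flatten" means each element is repeated n times consecutively (element b ends up at indices b*n through b*n + n - 1), rather than the whole batch being tiled. A minimal sketch with a hypothetical helper; for tensors, torch.repeat_interleave(x, n, dim=0) produces the same ordering.

```python
from typing import List, TypeVar

T = TypeVar("T")


def extend_batch(batch: List[T], n: int) -> List[T]:
    """Hypothetical helper: insert a size-n dimension after the batch dimension, then flatten.

    Element b of the input appears at output indices b*n ... b*n + n - 1.
    """
    return [item for item in batch for _ in range(n)]


print(extend_batch(["scene_a", "scene_b"], 3))
# ['scene_a', 'scene_a', 'scene_a', 'scene_b', 'scene_b', 'scene_b']
```

This ordering keeps all copies of one scene adjacent, so batch element b of the original corresponds to the contiguous block starting at b*n.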

select_batch_elements(idx: torch.Tensor, in_place: bool = True)

Picks selected elements of the batch. The input is a tensor of indices into the batch dimension.

set_state(state: torch.Tensor) → None

Sets control state tensor.

compute_state(time: int) → torch.Tensor

Computes the state at a given time index. By default, it repeats the current state. Ignores replay states.

step(time: int) → None

Advances the state of the control to the given time step. If the time value is within the range of replay states, the recorded state for that time is used; otherwise, compute_state(time) is called to calculate a new state.
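The replay-then-compute logic described above can be sketched in plain Python for a single control element. The class is hypothetical: its method names merely echo the documented step, compute_state, and total_replay_time members, and it uses a list of state indices instead of a BxNxT tensor. Treating "within the range" as time < total_replay_time is an assumption.

```python
from typing import List


class ReplayingControl:
    """Hypothetical single-element control: replays recorded states, then falls back to compute_state."""

    def __init__(self, allowed_states: List[str], replay_states: List[int], initial_state: int = 0):
        self.allowed_states = allowed_states
        self.replay_states = replay_states  # indices into allowed_states, one per timestep
        self.state = initial_state

    @property
    def total_replay_time(self) -> int:
        return len(self.replay_states)

    def compute_state(self, time: int) -> int:
        # Mirrors the documented default: repeat the current state.
        return self.state

    def step(self, time: int) -> None:
        if time < self.total_replay_time:
            self.state = self.replay_states[time]
        else:
            self.state = self.compute_state(time)

    @property
    def state_name(self) -> str:
        return self.allowed_states[self.state]


ctrl = ReplayingControl(["red", "yellow", "green"], replay_states=[2, 1, 0])
names = []
for t in range(5):
    ctrl.step(t)
    names.append(ctrl.state_name)
print(names)  # ['green', 'yellow', 'red', 'red', 'red']
```

Subclasses that need non-trivial dynamics beyond the recording would override compute_state while leaving step unchanged.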

compute_violation(agent_state: torch.Tensor) → torch.Tensor

Given a collection of agents, computes which agents violate the traffic control. As deciding violations is typically complex, this function provides an approximate answer only. The base class reports no violations.

Parameters:

agent_state – BxAx5 tensor of agent states: x, y, length, width, orientation

Returns:

BxA boolean tensor indicating which agents violated the traffic control

class torchdrivesim.map.TrafficLightControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Traffic lights, with default allowed states of ['red', 'yellow', 'green']. An agent is regarded as violating the traffic light if the light is red and the agent’s bounding box substantially overlaps with the stop line.

violation_rear_factor = 0.1
classmethod _default_allowed_states() → List[str]
compute_violation(agent_state: torch.Tensor) → torch.Tensor

Given a collection of agents, computes which agents violate the traffic control. As deciding violations is typically complex, this function provides an approximate answer only. The base class reports no violations.

Parameters:

agent_state – BxAx5 tensor of agent states: x, y, length, width, orientation

Returns:

BxA boolean tensor indicating which agents violated the traffic control

class torchdrivesim.map.StopSignControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Stop sign, indicating the vehicle should stop and yield to cross traffic. Violations are not computed.

class torchdrivesim.map.YieldControl(pos: torch.Tensor, allowed_states: List[str] | None = None, replay_states: torch.Tensor | None = None, mask: torch.Tensor | None = None)

Bases: BaseTrafficControl

Yield sign, indicating that cross traffic has priority. Violations are not computed.

class torchdrivesim.map.TrafficLightController(traffic_fsms: List[TrafficLightStateMachine])
classmethod from_json(json_file_path: str) → typing_extensions.Self

Current JSON format:

    [
      [
        {
          "actor_states": {
            "4411": "red",
            "3411": "red",
            ...
          },
          "state": "0",
          "duration": 10,
          "next_state": "1"
        }
      ]
    ]
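Reading the format above as a list of state machines, where each machine is a list of states with a duration and a next_state, a controller tick can be sketched as follows. The interpretation (durations in the same units as dt, leftover time carried into the next state, each machine starting in its first listed state, positive durations) is an assumption, the actor states in the sample input are made up, and TinyLightMachine is a hypothetical class rather than the library's TrafficLightStateMachine.

```python
import json
from typing import Dict, List


class TinyLightMachine:
    """Hypothetical state machine built from one inner list of the JSON format."""

    def __init__(self, states: List[dict]):
        self.states: Dict[str, dict] = {s["state"]: s for s in states}
        self.current = states[0]["state"]  # assumption: start in the first listed state
        self.time_remaining = float(self.states[self.current]["duration"])

    def tick(self, dt: float) -> None:
        self.time_remaining -= dt
        # Advance through as many states as the elapsed time covers (assumes positive durations).
        while self.time_remaining <= 0:
            self.current = self.states[self.current]["next_state"]
            self.time_remaining += self.states[self.current]["duration"]

    def actor_states(self) -> Dict[str, str]:
        return self.states[self.current]["actor_states"]


config = json.loads("""
[
  [
    {"actor_states": {"4411": "red", "3411": "green"}, "state": "0", "duration": 10, "next_state": "1"},
    {"actor_states": {"4411": "green", "3411": "red"}, "state": "1", "duration": 5, "next_state": "0"}
  ]
]
""")
machines = [TinyLightMachine(states) for states in config]
for m in machines:
    m.tick(12.0)
print(machines[0].current, machines[0].time_remaining)  # 1 3.0
```

Keeping one machine per inner list lets independent light groups cycle on their own schedules, which matches the per-machine properties (state_per_machine, time_remaining) listed below.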

to_json() → str
tick(dt)
set_to(light_states: List[List[float]])
reset()
update_current_state_and_time()
property current_state
property current_state_with_name
property state_per_machine
property time_remaining
get_number_of_light_groups()
collect_all_current_light_states()
torchdrivesim.map.normalize_angle(angle)

Normalize to the [-pi, pi) range by shifting by a multiple of 2*pi. Works with floats, numpy arrays, and torch tensors.
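A common way to implement this normalization is the modular formula (angle + pi) mod 2*pi - pi. The sketch below applies it to plain floats; it illustrates the technique under that assumption and is not necessarily the library's exact code. Python's % returns a result with the sign of the divisor, so the output lands in [-pi, pi) up to floating-point rounding.

```python
import math


def normalize_angle_sketch(angle: float) -> float:
    """Shift angle by a multiple of 2*pi so that the result lies in [-pi, pi)."""
    return (angle + math.pi) % (2 * math.pi) - math.pi


print(normalize_angle_sketch(3 * math.pi / 2))  # approximately -pi/2
print(normalize_angle_sketch(math.pi))          # -pi, since the interval is closed on the left
```

The same expression should also work element-wise on NumPy arrays and torch tensors, whose % operators likewise follow the sign of the divisor.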

class torchdrivesim.map.Stopline
actor_id: int
agent_type: str
x: float
y: float
length: float
width: float
orientation: float
__post_init__()
class torchdrivesim.map.MapConfig

Encapsulates various map metadata, including where to find files defining the map. Map definition includes a coordinate frame and traffic signals.

name: str
left_handed_coordinates: bool = False
center: Tuple[float, float] | None = None
lanelet_path: str | None = None
lanelet_map_origin: Tuple[float, float] = (0, 0)
mesh_path: str | None = None
stoplines_path: str | None = None
traffic_light_controller_path: str | None = None
iai_location_name: str | None = None
note: str | None = None
property lanelet_map: torchdrivesim.lanelet2.LaneletMap | None
property road_mesh: torchdrivesim.mesh.BirdviewMesh | None
property stoplines: List[Stopline]
property traffic_light_controller: torchdrivesim.traffic_lights.TrafficLightController | None
torchdrivesim.map._filename_defaults(name: str) → Dict[str, str]
torchdrivesim.map.resolve_paths_to_absolute(cfg: MapConfig, root: str) → MapConfig
torchdrivesim.map.load_map_config(json_path: str, resolve_paths: bool = True) → MapConfig
torchdrivesim.map.store_map_config(cfg: MapConfig, json_path: str, store_absolute_paths: bool = False) → None
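These config helpers are undocumented here. Judging only by their names and the MapConfig fields, resolving paths presumably joins the relative *_path fields onto a root directory. The sketch below illustrates that idea with a trimmed-down, hypothetical stand-in dataclass and a JSON round trip; which fields the library treats as paths, and its exact JSON layout, are assumptions.

```python
import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniMapConfig:
    """Hypothetical stand-in with a subset of MapConfig's fields."""
    name: str
    left_handed_coordinates: bool = False
    lanelet_path: Optional[str] = None
    mesh_path: Optional[str] = None
    stoplines_path: Optional[str] = None


PATH_FIELDS = ("lanelet_path", "mesh_path", "stoplines_path")  # assumption: fields holding file paths


def resolve_paths(cfg: MiniMapConfig, root: str) -> MiniMapConfig:
    """Return a copy with relative path fields joined onto root; absolute paths are left alone."""
    updates = {}
    for field_name in PATH_FIELDS:
        value = getattr(cfg, field_name)
        if value is not None and not os.path.isabs(value):
            updates[field_name] = os.path.join(root, value)
    return dataclasses.replace(cfg, **updates)


cfg = MiniMapConfig(name="example_map", lanelet_path="example_map.osm", mesh_path="/data/mesh.json")
# Store-then-load round trip through JSON, as a config file on disk would do.
restored = MiniMapConfig(**json.loads(json.dumps(dataclasses.asdict(cfg))))
resolved = resolve_paths(restored, "/maps/example_map")
print(resolved.lanelet_path)  # /maps/example_map/example_map.osm on POSIX systems
```

Tuple fields such as center would come back from JSON as lists, which is one reason a real loader needs more care than this sketch.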
torchdrivesim.map.find_map_config(map_name: str, resolve_paths: bool = True) → MapConfig | None

To retrieve configs for maps not bundled with the package, folders with corresponding names must be placed inside one of the directories listed in the TDS_RESOURCE_PATH environment variable. Note that map names should be globally unique.
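The lookup described above can be sketched as a scan over the directories in TDS_RESOURCE_PATH for a folder named after the map. This assumes the variable holds os.pathsep-separated directories (like PATH) and that the first match wins; the helper is hypothetical, and it returns only the folder, not the config file the library reads from it.

```python
import os
from typing import Optional


def find_map_folder(map_name: str, env_var: str = "TDS_RESOURCE_PATH") -> Optional[str]:
    """Hypothetical helper: return the first <dir>/<map_name> folder found in the env var's directories."""
    for root in os.environ.get(env_var, "").split(os.pathsep):
        if not root:
            continue  # skip empty entries, e.g. from a trailing separator or an unset variable
        candidate = os.path.join(root, map_name)
        if os.path.isdir(candidate):
            return candidate
    return None
```

Because the first matching directory wins, two resource directories containing folders with the same map name would shadow each other, which is why map names should be globally unique.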

torchdrivesim.map.download_iai_map(location_name: str, save_path: str) → None

Downloads the map information using LOCATION_INFO from the Inverted AI API. IAI_API_KEY must be set for this to succeed. The basename of save_path is used as the map name for TorchDriveSim. If the parent directory of save_path is listed in TDS_RESOURCE_PATH, the map becomes immediately available to find_map_config.

torchdrivesim.map.traffic_controls_from_map_config(cfg: MapConfig) → Dict[str, torchdrivesim.traffic_controls.BaseTrafficControl]
torchdrivesim.map.find_wrong_way_stoplines(map_cfg: MapConfig, angle_threshold: float = np.pi / 6) → List[int]