import os
from typing_extensions import Self
import numpy as np
import pandas as pd
import torch
from torchdrivesim.behavior.common import InitializationFailedError
from torchdrivesim.simulator import NPCWrapper, SimulatorInterface, TensorPerAgentType
from torchdrivesim.utils import assert_equal
def interaction_replay(location, dataset_path, initial_frame=1, segment_length=40, recording=0,
                       rear_offset=1.4):
    """
    Load a segment of a recorded vehicle track file and convert it to dense replay tensors.

    The file is read from
    ``<dataset_path>/recorded_trackfiles/<location>/vehicle_tracks_<recording:03d>.csv``
    and must provide the columns ``track_id``, ``frame_id``, ``x``, ``y``, ``vx``, ``vy``,
    ``psi_rad``, ``length`` and ``width``.

    Args:
        location: name of the sub-directory of ``recorded_trackfiles`` to read from
        dataset_path: root directory of the dataset
        initial_frame: first ``frame_id`` included in the segment
        segment_length: number of frames spanned, so the last included frame is
            ``initial_frame + segment_length - 1``
        recording: index of the recording, zero-padded to three digits in the file name
        rear_offset: value stored as the third attribute of every agent (defaults to 1.4)

    Returns:
        A tuple ``(agent_attributes, agent_states, present_mask)`` where, with A the number of
        distinct tracks in the segment and T the number of distinct frame ids in the segment:

        * ``agent_attributes`` has shape 1xAx3 holding (length, width, rear_offset),
          averaged over the rows of each track,
        * ``agent_states`` has shape 1xAxTx4 holding (x, y, psi_rad, speed), zero-filled
          where a track has no row for a frame,
        * ``present_mask`` is a boolean 1xAxT tensor, True where the track has a row.

    Raises:
        InitializationFailedError: if the initial or final frame id does not occur in the file.
    """
    recording_path = os.path.join(
        dataset_path, 'recorded_trackfiles', location, f'vehicle_tracks_{recording:03d}.csv'
    )
    df = pd.read_csv(recording_path)
    final_frame = initial_frame + segment_length - 1

    # Collect the available frame ids once instead of once per checked frame.
    available_frames = set(df.frame_id.unique().tolist())
    for frame in (initial_frame, final_frame):
        if frame not in available_frames:
            raise InitializationFailedError(f'Frame {frame} not available in {recording_path}')

    df = df[(df.frame_id >= initial_frame) & (df.frame_id <= final_frame)].copy()
    df = df.sort_values(['track_id', 'frame_id'])
    df['rear_offset'] = rear_offset

    agent_ids = sorted(df.track_id.unique())
    # A sorted groupby visits the tracks in the same order as `agent_ids`, in a single pass
    # rather than re-filtering the whole frame for every agent.
    agent_attributes = [
        torch.from_numpy(track[['length', 'width', 'rear_offset']].to_numpy()).mean(dim=-2)
        for _, track in df.groupby('track_id', sort=True)
    ]
    agent_attributes = torch.stack(agent_attributes, dim=-2).unsqueeze(0)

    df['present'] = True
    df['speed'] = np.sqrt(df.vx ** 2 + df.vy ** 2)

    # Build a dense (track, frame) grid; entries without a recorded row get zero state
    # and present=False from the padding frame.
    frame_ids = sorted(df.frame_id.unique())
    dense_index = pd.MultiIndex.from_product([agent_ids, frame_ids], names=["track_id", "frame_id"])
    padding = pd.DataFrame(index=dense_index, data=dict(x=0.0, y=0.0, psi_rad=0.0, speed=0.0, present=False))
    daug = df.set_index(['track_id', 'frame_id']).reindex(dense_index).combine_first(padding)

    agent_states = torch.from_numpy(daug[['x', 'y', 'psi_rad', 'speed']].to_numpy()).unsqueeze(0)
    agent_states = agent_states.reshape(1, len(agent_ids), len(frame_ids), 4)
    present_mask = torch.from_numpy(daug['present'].astype(bool).to_numpy()).unsqueeze(0)
    present_mask = present_mask.reshape(1, len(agent_ids), len(frame_ids))
    return agent_attributes, agent_states, present_mask
class ReplayWrapper(NPCWrapper):
    """
    Performs log replay for a subset of agents based on their recorded trajectories.
    The log has a finite length T, after which the replay will loop back from the beginning.

    Args:
        simulator: existing simulator to wrap
        npc_mask: A functor of tensors with a single dimension of size A, indicating which agents to replay.
            The tensors can not have batch dimensions.
        agent_states: a functor of BxAxTxSt tensors with states to replay across time,
            which should be padded with arbitrary values for non-replay agents
        present_masks: indicates when replay agents appear and disappear; by default they're all present at all times
        time: initial index into the time dimension for replay, incremented at every step
    """

    def __init__(self, simulator: SimulatorInterface, npc_mask: TensorPerAgentType,
                 agent_states: TensorPerAgentType, present_masks: TensorPerAgentType = None, time: int = 0):
        super().__init__(simulator=simulator, npc_mask=npc_mask)

        self.replay_states = agent_states
        self.present_masks = present_masks
        self.time = time
        # Length of the time dimension (second to last axis); the largest one across agent types.
        self.max_time_step = max(self.across_agent_types(lambda x: x.shape[-2], agent_states).values())

        if self.present_masks is None:
            # by default all replay agents are always present
            self.present_masks = self.across_agent_types(
                lambda states: torch.ones_like(states[..., 0], dtype=torch.bool), self.replay_states
            )

        self.validate_agent_types()
        self.validate_tensor_shapes()

    def _npc_teleport_to(self):
        """Return, per agent type, the replay states at the current time index."""
        current_replay_state = self.across_agent_types(
            lambda states: states[..., self.time, :], self.replay_states
        )
        return current_replay_state

    def _update_npc_present_mask(self):
        """Return, per agent type, the present mask at the current time index."""
        return self.across_agent_types(lambda pm: pm[..., self.time], self.present_masks)

    def step(self, action):
        """
        Advance the replay time index by one, wrapping to 0 at the end of the log,
        then step the wrapped simulator.

        Returns:
            Whatever the parent class `step` returns.
        """
        self.time += 1
        # reset time if needed
        if self.time == self.max_time_step:
            self.time = 0
        return super().step(action)

    def to(self, device) -> Self:
        """
        Move the wrapper, including the replay states and present masks, to the given device.

        Returns:
            This wrapper, as the `Self` annotation promises.
        """
        super().to(device)
        self.replay_states = self.agent_functor.to_device(self.replay_states, device)
        self.present_masks = self.agent_functor.to_device(self.present_masks, device)
        return self

    def copy(self):
        """
        Create a new wrapper of the same class around a copy of the inner simulator.

        The npc mask, replay states and present masks are passed on as they are (not cloned),
        and the current time index is preserved.
        """
        inner_copy = self.inner_simulator.copy()
        other = self.__class__(inner_copy, npc_mask=self.npc_mask, agent_states=self.replay_states,
                               present_masks=self.present_masks, time=self.time)
        return other

    def extend(self, n, in_place=True):
        """
        Multiply the batch dimension of the replay data by `n`, repeating each batch element
        `n` times consecutively, after extending the inner simulator.

        Args:
            n: multiplier for the batch dimension
            in_place: when False, the work is delegated to the parent class implementation

        Returns:
            The parent's result when `in_place` is False, otherwise this wrapper.
        """
        if not in_place:
            return super().extend(n, in_place=in_place)

        self.inner_simulator.extend(n)

        def enlarge(x):
            # Insert a new axis after the batch axis, broadcast it to size n, then flatten both.
            return x.unsqueeze(1).expand((x.shape[0], n) + x.shape[1:]).reshape((n * x.shape[0],) + x.shape[1:])

        self.replay_states = self.across_agent_types(enlarge, self.replay_states)
        self.present_masks = self.across_agent_types(enlarge, self.present_masks)
        # NOTE(review): `_batch_size` is not updated here, although select_batch_elements
        # sets it explicitly - confirm whether the parent class keeps it in sync.
        return self

    def select_batch_elements(self, idx, in_place=True):
        """
        Keep only the batch elements given by `idx` in the replay states and present masks,
        on the object returned by the parent class implementation.

        Returns:
            The object returned by the parent's `select_batch_elements`, with its replay data
            indexed by `idx` and its `_batch_size` set to `len(idx)`.
        """
        other = super().select_batch_elements(idx, in_place=in_place)
        other.replay_states = other.across_agent_types(lambda x: x[idx], other.replay_states)
        other.present_masks = other.across_agent_types(lambda x: x[idx], other.present_masks)
        other._batch_size = len(idx)
        return other

    def validate_agent_types(self):
        """Assert that the npc mask, replay states and present masks are keyed by exactly `self.agent_types`."""
        assert list(self.npc_mask.keys()) == self.agent_types
        assert list(self.replay_states.keys()) == self.agent_types
        assert list(self.present_masks.keys()) == self.agent_types

    def validate_tensor_shapes(self):
        """Check ranks, batch sizes and agent counts of the stored tensors via `assert_equal`."""
        # check that tensors have the expected number of dimensions
        self.across_agent_types(lambda m: assert_equal(len(m.shape), 1), self.npc_mask)
        self.across_agent_types(lambda s: assert_equal(len(s.shape), 4), self.replay_states)
        self.across_agent_types(lambda m: assert_equal(len(m.shape), 3), self.present_masks)

        # check that batch size is the same everywhere
        b = self.batch_size
        self.across_agent_types(lambda s: assert_equal(s.shape[0], b), self.replay_states)
        self.across_agent_types(lambda m: assert_equal(m.shape[0], b), self.present_masks)

        # check that the number of agents in replay is the same as in underlying simulator
        check_counts = lambda i: lambda x, y: assert_equal(x.shape[i], y)
        self.across_agent_types(check_counts(0), self.npc_mask, self.inner_simulator.agent_count)
        self.across_agent_types(check_counts(-3), self.replay_states, self.inner_simulator.agent_count)
        self.across_agent_types(check_counts(-2), self.present_masks, self.inner_simulator.agent_count)