AbstractAgent 详细文档
文件信息
- 路径:
navsim/agents/abstract_agent.py - 作用: 定义所有Agent的抽象基类接口
- 重要性: ⭐⭐⭐⭐⭐ (系统核心接口)
类继承关系
torch.nn.Module
│
└── AbstractAgent (ABC)
│
├── ConstantVelocityAgent
├── EgoStatusMLPAgent
├── TransfuserAgent
└── HumanAgent
核心属性
1. trajectory_sampling (TrajectorySampling)
- 类型:
TrajectorySampling - 作用: 定义轨迹的时间采样参数
- 包含:
- 时间跨度(如4秒)
- 采样频率(如10Hz)
2. requires_scene (bool)
- 类型:
bool - 默认值:
False - 作用: 标识Agent是否需要完整的场景信息
- 说明: 大部分Agent只需要AgentInput,特殊Agent(如HumanAgent)需要Scene
抽象方法(必须实现)
1. name() -> str
@abstractmethod
def name(self) -> str:
"""返回Agent的名称"""
- 作用: 返回Agent的标识名称
- 用途: 用于生成评估CSV文件名、日志记录等
- 示例返回: "ConstantVelocityAgent", "TransfuserAgent"
2. get_sensor_config() -> SensorConfig
@abstractmethod
def get_sensor_config(self) -> SensorConfig:
"""定义传感器配置"""
- 作用: 指定Agent需要哪些传感器数据
- 重要性: 直接影响数据加载性能
- 配置选项:
- 8个相机视角(cam_f0, cam_l0-l2, cam_r0-r2, cam_b0)
- LiDAR点云
- 历史帧数量
- 性能提示: 只加载必要的传感器可大幅提升速度
3. initialize() -> None
@abstractmethod
def initialize(self) -> None:
"""初始化Agent"""
- 调用时机: 推理前调用一次
- 典型操作:
- 加载模型权重
- 初始化网络结构
- 设置设备(GPU/CPU)
- 多进程注意: 每个worker会独立调用此方法
可选方法(学习型Agent)
1. forward(features: Dict) -> Dict
def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""神经网络前向传播"""
- 输入: 特征字典(已批处理)
- 输出: 必须包含 "trajectory" 键
- 轨迹格式: [B, T, 3] (批次, 时间步, x/y/heading)
2. get_feature_builders() -> List[AbstractFeatureBuilder]
def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
"""返回特征构建器列表"""
- 作用: 定义如何从AgentInput提取特征
- 常见构建器:
- EgoStatusFeatureBuilder: 提取车辆状态
- TransfuserFeatureBuilder: 处理图像和LiDAR
3. get_target_builders() -> List[AbstractTargetBuilder]
def get_target_builders(self) -> List[AbstractTargetBuilder]:
"""返回目标构建器列表"""
- 作用: 定义训练目标(ground truth)
- 访问权限: 可访问Scene中的真实轨迹
4. compute_loss(features, targets, predictions) -> torch.Tensor
def compute_loss(self, features, targets, predictions) -> torch.Tensor:
"""计算训练损失"""
- 输入: 特征、目标、预测
- 输出: 标量损失值
- 常见损失: L2距离、加权多任务损失
5. get_optimizers() -> Optimizer
def get_optimizers(self) -> Union[Optimizer, Dict]:
"""返回优化器配置"""
- 简单返回: 单个优化器
- 复杂返回: {"optimizer": opt, "lr_scheduler": scheduler}
核心方法实现
compute_trajectory(agent_input: AgentInput) -> Trajectory
def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
"""计算车辆轨迹(核心方法)"""
self.eval() # 设置为评估模式
features: Dict[str, torch.Tensor] = {}
# 1. 特征提取
for builder in self.get_feature_builders():
features.update(builder.compute_features(agent_input))
# 2. 添加批次维度
features = {k: v.unsqueeze(0) for k, v in features.items()}
# 3. 前向传播(无梯度)
with torch.no_grad():
predictions = self.forward(features)
poses = predictions["trajectory"].squeeze(0).numpy()
# 4. 构造轨迹对象
return Trajectory(poses, self._trajectory_sampling)
执行流程:
- 切换到评估模式(关闭dropout等)
- 通过特征构建器提取特征
- 添加批次维度(推理时batch_size=1)
- 执行前向传播获得轨迹预测
- 封装为Trajectory对象返回
数据流
AgentInput
↓
Feature Builders
↓
Features Dict
↓
Forward Pass
↓
Predictions Dict
↓
Trajectory Object
使用示例
创建简单Agent
class SimpleAgent(AbstractAgent):
def name(self) -> str:
return "SimpleAgent"
def get_sensor_config(self) -> SensorConfig:
return SensorConfig(lidar_pc=False, cameras=False)
def initialize(self) -> None:
pass # 无需初始化
def compute_trajectory(self, agent_input: AgentInput) -> Trajectory:
# 直接返回直行轨迹
poses = np.zeros((40, 3)) # 4秒, 10Hz
return Trajectory(poses, self._trajectory_sampling)
创建学习型Agent
class LearnedAgent(AbstractAgent):
def __init__(self, trajectory_sampling):
super().__init__(trajectory_sampling)
self.model = MyNeuralNetwork()
def forward(self, features: Dict) -> Dict:
trajectory = self.model(features["ego_status"])
return {"trajectory": trajectory}
def get_feature_builders(self):
return [EgoStatusFeatureBuilder()]
def compute_loss(self, features, targets, predictions):
return F.mse_loss(predictions["trajectory"], targets["trajectory"])
设计模式
1. 模板方法模式
compute_trajectory()定义了算法骨架- 子类通过实现抽象方法来定制行为
2. 策略模式
- 不同Agent实现不同的轨迹计算策略
- 运行时可替换不同的Agent实现
3. 依赖注入
- 通过构造函数注入
trajectory_sampling - 便于测试和配置
性能优化建议
-
传感器配置优化 - 只加载必要的传感器 - 减少历史帧数量 - 考虑降低图像分辨率
-
推理优化 - 使用
torch.no_grad()避免梯度计算 - 批处理多个场景 - 使用半精度(FP16)推理 -
缓存策略 - 缓存不变的特征 - 重用中间计算结果
扩展点
- 新传感器模态: 扩展SensorConfig支持新传感器
- 多模态融合: 在forward中实现复杂的融合策略
- 在线学习: 添加在线适应能力
- 不确定性估计: 输出轨迹的置信度
注意事项
- 线程安全: 多worker环境下每个worker独立初始化
- 设备管理: 注意GPU/CPU的正确设置
- 内存管理: 大批量数据时注意内存使用
- 异常处理: 实现健壮的错误处理机制