Skip to content

Commit

Permalink
Add shader uniforms test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Jan 18, 2024
1 parent cfec9f3 commit 3e30a9c
Show file tree
Hide file tree
Showing 14 changed files with 374 additions and 74 deletions.
2 changes: 1 addition & 1 deletion examples/occlusion/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def downgrade_waypoints(

def downgrade_vehicles(
center: Tuple[float, float],
neighborhood_vehicle_states: List[VehicleObservation],
neighborhood_vehicle_states: Tuple[VehicleObservation],
mode=ObservationOptions.multi_agent,
):
if mode:
Expand Down
9 changes: 7 additions & 2 deletions smarts/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

import warnings
from typing import NamedTuple, Sequence
from typing import TYPE_CHECKING, NamedTuple, Tuple

if TYPE_CHECKING:
from smarts.core.vehicle_state import Collision


class Events(NamedTuple):
"""Classified observations that can trigger agent done status."""

collisions: Sequence # Sequence[Collision]
collisions: Tuple[Collision]
"""Collisions with other vehicles (if any)."""
off_road: bool
"""True if vehicle is off the road, else False."""
Expand Down
127 changes: 127 additions & 0 deletions smarts/core/glsl/test_custom_shader_pass_shader.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#version 330 core
// This script is intended to test that all of the observation buffers work in association with
// the test script `test_renderers.py`.

uniform int step_count;
uniform int steps_completed;
uniform int events_off_road;
uniform int events_off_route;
uniform int events_on_shoulder;
uniform int events_wrong_way;
uniform int events_not_moving;
uniform int events_reached_goal;
uniform int events_reached_max_episode_steps;
uniform int events_agents_done_alive;
uniform int events_interest_done;
uniform int ego_vehicle_state_road_id;
uniform int ego_vehicle_state_lane_id;
uniform int ego_vehicle_state_lane_index;
uniform int under_this_vehicle_control;
uniform int vehicle_type;

uniform float dt;
uniform float ego_vehicle_state_heading;
uniform float ego_vehicle_state_speed;
uniform float ego_vehicle_state_steering;
uniform float ego_vehicle_state_yaw_rate;
uniform float elapsed_sim_time;
uniform float distance_travelled;

uniform vec3 ego_vehicle_state_position;
uniform vec3 ego_vehicle_state_bounding_box;
uniform vec3 ego_vehicle_state_lane_position;

uniform int neighborhood_vehicle_states_road_id[10];
uniform int neighborhood_vehicle_states_lane_id[10];
uniform int neighborhood_vehicle_states_lane_index[10];
uniform int neighborhood_vehicle_states_interest[10];
uniform int waypoint_paths_lane_id[10];
uniform int waypoint_paths_lane_index[10];
uniform int road_waypoints_lanes_lane_id[10];
uniform int road_waypoints_lanes_lane_index[10];
uniform int via_data_near_via_points_lane_index[10];
uniform int via_data_near_via_points_road_id[10];
uniform int via_data_near_via_points_hit[10];
uniform int lidar_point_cloud_hits[100];
uniform int signals_light_state[10];

uniform float neighborhood_vehicle_states_heading[10];
uniform float neighborhood_vehicle_states_speed[10];
uniform float waypoint_paths_heading[10];
uniform float waypoint_paths_lane_width[10];
uniform float waypoint_paths_speed_limit[10];
uniform float waypoint_paths_lane_offset[10];
uniform float road_waypoints_lanes_heading[10];
uniform float road_waypoints_lanes_width[10];
uniform float road_waypoints_lanes_speed_limit[10];
uniform float road_waypoints_lanes_lane_offset[10];
uniform float via_data_near_via_points_required_speed[10];
uniform float signals_last_changed[10];

uniform vec2 ego_vehicle_state_linear_velocity[10];
uniform vec2 ego_vehicle_state_angular_velocity[10];
uniform vec2 ego_vehicle_state_linear_acceleration[10];
uniform vec2 ego_vehicle_state_angular_acceleration[10];
uniform vec2 ego_vehicle_state_linear_jerk[10];
uniform vec2 ego_vehicle_state_angular_jerk[10];
uniform vec2 waypoint_paths_pos[10];
uniform vec2 road_waypoints_lanes_pos[10];
uniform vec2 via_data_near_via_points_position[10];
uniform vec2 signals_stop_point[10];

uniform vec3 neighborhood_vehicle_states_position[10];
uniform vec3 neighborhood_vehicle_states_bounding_box[10];
uniform vec3 neighborhood_vehicle_states_lane_position[10];
uniform vec3 lidar_point_cloud_points[100];
uniform vec3 lidar_point_cloud_origin[100];
uniform vec3 lidar_point_cloud_direction[100];
// SIGNALS_CONTROLLED_LANES = "signals_controlled_lanes"

// Output color
out vec4 p3d_Color;

uniform vec2 iResolution;


void mainImage( out vec4 fragColor, in vec2 fragCoord )
{
vec2 rec_res = 1.0 / iResolution.xy;
vec2 p = fragCoord.xy * rec_res;

fragColor = vec4(0.0, 0.0, 0.0, 0.0);
}

#ifndef SHADERTOY
void main(){
int a = step_count + steps_completed + events_off_road + events_off_route
+ events_on_shoulder + events_wrong_way + events_not_moving
+ events_reached_goal + events_reached_max_episode_steps
+ events_agents_done_alive + events_interest_done + ego_vehicle_state_road_id
+ ego_vehicle_state_lane_id + ego_vehicle_state_lane_index + under_this_vehicle_control + vehicle_type;
float b = dt + ego_vehicle_state_heading + ego_vehicle_state_speed
+ ego_vehicle_state_steering + ego_vehicle_state_yaw_rate + elapsed_sim_time
+ distance_travelled;
vec3 c = ego_vehicle_state_position + ego_vehicle_state_bounding_box + ego_vehicle_state_lane_position;
int d = neighborhood_vehicle_states_road_id[0] + neighborhood_vehicle_states_lane_id[0]
+ neighborhood_vehicle_states_lane_index[0] + neighborhood_vehicle_states_interest[0]
+ waypoint_paths_lane_id[0] + waypoint_paths_lane_index[0] + road_waypoints_lanes_lane_id[0]
+ road_waypoints_lanes_lane_index[0] + via_data_near_via_points_lane_index[0]
+ via_data_near_via_points_road_id[0] + via_data_near_via_points_hit[0]
+ lidar_point_cloud_hits[0] + signals_light_state[0];
float e = neighborhood_vehicle_states_heading[0] + neighborhood_vehicle_states_speed[0]
+ waypoint_paths_heading[0] + waypoint_paths_lane_width[0] + waypoint_paths_speed_limit[0]
+ waypoint_paths_lane_offset[0] + road_waypoints_lanes_heading[0] + road_waypoints_lanes_width[0]
+ road_waypoints_lanes_speed_limit[0] + road_waypoints_lanes_lane_offset[0]
+ via_data_near_via_points_required_speed[0] + signals_last_changed[0];
vec2 f = ego_vehicle_state_linear_velocity[0] + ego_vehicle_state_angular_velocity[0]
+ ego_vehicle_state_linear_acceleration[0] + ego_vehicle_state_angular_acceleration[0]
+ ego_vehicle_state_linear_jerk[0] + ego_vehicle_state_angular_jerk[0]
+ waypoint_paths_pos[0] + road_waypoints_lanes_pos[0] + via_data_near_via_points_position[0]
+ signals_stop_point[0];
vec3 g = neighborhood_vehicle_states_position[0] + neighborhood_vehicle_states_bounding_box[0]
+ neighborhood_vehicle_states_lane_position[0] + lidar_point_cloud_points[0]
+ lidar_point_cloud_origin[0] + lidar_point_cloud_direction[0];

mainImage( p3d_Color, gl_FragCoord.xy );
}
#endif
54 changes: 48 additions & 6 deletions smarts/core/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import numpy as np

from smarts.core.utils.cache import cache

if TYPE_CHECKING:
from smarts.core import plan, signals
from smarts.core.coordinates import Dimensions, Heading, Point, RefLinePoint
Expand Down Expand Up @@ -106,6 +108,9 @@ class RoadWaypoints(NamedTuple):
lanes: Dict[str, List[List[Waypoint]]]
"""Mapping of road ids to their lane waypoints."""

def __hash__(self) -> int:
return hash(tuple((k, len(v)) for k, v in self.lanes.items()))


class GridMapMetadata(NamedTuple):
"""Map grid metadata."""
Expand All @@ -130,6 +135,9 @@ class TopDownRGB(NamedTuple):
data: np.ndarray
"""A RGB image with the ego vehicle at the center."""

def __hash__(self) -> int:
return self.metadata.__hash__()


class OccupancyGridMap(NamedTuple):
"""Occupancy map."""
Expand All @@ -141,6 +149,9 @@ class OccupancyGridMap(NamedTuple):
See https://en.wikipedia.org/wiki/Occupancy_grid_mapping."""

def __hash__(self) -> int:
return self.metadata.__hash__()


class ObfuscationGridMap(NamedTuple):
"""Obfuscation map."""
Expand All @@ -150,6 +161,9 @@ class ObfuscationGridMap(NamedTuple):
data: np.ndarray
"""A map showing what is visible from the ego vehicle"""

def __hash__(self) -> int:
return self.metadata.__hash__()


class DrivableAreaGridMap(NamedTuple):
"""Drivable area map."""
Expand All @@ -159,6 +173,9 @@ class DrivableAreaGridMap(NamedTuple):
data: np.ndarray
"""A grid map that shows the static drivable area around the ego vehicle."""

def __hash__(self) -> int:
return self.metadata.__hash__()


class CustomRenderData(NamedTuple):
"""Describes information about a custom render."""
Expand All @@ -168,6 +185,9 @@ class CustomRenderData(NamedTuple):
data: np.ndarray
"""The image data from the render."""

def __hash__(self) -> int:
return self.metadata.__hash__()


class ViaPoint(NamedTuple):
"""'Collectibles' that can be placed within the simulation."""
Expand All @@ -187,13 +207,13 @@ class ViaPoint(NamedTuple):
class Vias(NamedTuple):
"""Listing of nearby collectible ViaPoints and ViaPoints collected in the last step."""

near_via_points: List[ViaPoint]
near_via_points: Tuple[ViaPoint]
"""Ordered list of nearby points that have not been hit."""

@property
def hit_via_points(self) -> List[ViaPoint]:
def hit_via_points(self) -> Tuple[ViaPoint]:
"""List of points that were hit in the previous step."""
return [vp for vp in self.near_via_points if vp.hit]
return tuple(vp for vp in self.near_via_points if vp.hit)


class SignalObservation(NamedTuple):
Expand All @@ -204,7 +224,7 @@ class SignalObservation(NamedTuple):
stop_point: Point
"""The stopping point for traffic controlled by the signal, i.e., the
point where actors should stop when the signal is in a stop state."""
controlled_lanes: List[str]
controlled_lanes: Tuple[str]
"""If known, the lane_ids of all lanes controlled-by this signal.
May be empty if this is not easy to determine."""
last_changed: Optional[float]
Expand All @@ -228,7 +248,7 @@ class Observation(NamedTuple):
"""Ego vehicle status."""
under_this_agent_control: bool
"""Whether this agent currently has control of the vehicle."""
neighborhood_vehicle_states: Optional[List[VehicleObservation]]
neighborhood_vehicle_states: Optional[Tuple[VehicleObservation]]
"""List of neighborhood vehicle states."""
waypoint_paths: Optional[List[List[Waypoint]]]
"""Dynamic evenly-spaced points on the road ahead of the vehicle, showing potential routes ahead."""
Expand All @@ -250,9 +270,31 @@ class Observation(NamedTuple):
"""Occupancy map."""
top_down_rgb: Optional[TopDownRGB] = None
"""RGB camera observation."""
signals: Optional[List[SignalObservation]] = None
signals: Optional[Tuple[SignalObservation]] = None
"""List of nearby traffic signal (light) states on this time-step."""
obfuscation_grid_map: Optional[ObfuscationGridMap] = None
"""Observable area map."""
custom_renders: Tuple[CustomRenderData, ...] = tuple()
"""Custom renders."""

def __hash__(self):
return hash(
(
self.dt,
self.step_count,
self.elapsed_sim_time,
self.events,
self.ego_vehicle_state,
# self.waypoint_paths, # likely redundant
self.neighborhood_vehicle_states,
self.distance_travelled,
self.road_waypoints,
self.via_data,
self.drivable_area_grid_map,
self.occupancy_grid_map,
self.top_down_rgb,
self.signals,
self.obfuscation_grid_map,
self.custom_renders,
)
)
13 changes: 13 additions & 0 deletions smarts/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ def from_pose(cls, pose: Pose):
from_front_bumper=False,
)

def __hash__(self):
hash_ = getattr(self, "hash", None)
if not hash_:
hash_ = hash(
(
tuple(self.position),
self.heading,
self.from_front_bumper,
)
)
object.__setattr__(self, "hash", hash_)
return hash_


@dataclass(frozen=True, unsafe_hash=True)
class Goal:
Expand Down
6 changes: 3 additions & 3 deletions smarts/core/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def closest_point_on_lane(position, lane_id, road_map):
self._consumed_via_points.add(via)

near_points.sort(key=lambda dist, _: dist)
return [p for _, p in near_points]
return tuple(p for _, p in near_points)

def teardown(self, **kwargs):
pass
Expand Down Expand Up @@ -1098,12 +1098,12 @@ def __call__(
SignalObservation(
state=signal_state.state,
stop_point=signal_state.stopping_pos,
controlled_lanes=controlled_lanes,
controlled_lanes=tuple(controlled_lanes),
last_changed=signal_state.last_changed,
)
)

return result
return tuple(result)

def _find_signals_ahead(
self,
Expand Down
8 changes: 5 additions & 3 deletions smarts/core/sensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def process_serialization_safe_sensors(
else None
)

near_via_points = []
near_via_points = ()

via_sensor = vehicle_sensors.get("via_sensor")
if via_sensor:
Expand Down Expand Up @@ -498,7 +498,7 @@ def process_serialization_safe_sensors(
events=events,
ego_vehicle_state=ego_vehicle,
under_this_agent_control=agent_controls,
neighborhood_vehicle_states=neighborhood_vehicle_states,
neighborhood_vehicle_states=neighborhood_vehicle_states or (),
waypoint_paths=waypoint_paths,
distance_travelled=distance_travelled,
road_waypoints=road_waypoints,
Expand Down Expand Up @@ -667,7 +667,9 @@ def _is_done_with_events(
)

events = Events(
collisions=sim_frame.filtered_vehicle_collisions(vehicle_state.actor_id),
collisions=tuple(
sim_frame.filtered_vehicle_collisions(vehicle_state.actor_id)
),
off_road=is_off_road,
reached_goal=reached_goal,
reached_max_episode_steps=reached_max_episode_steps,
Expand Down
7 changes: 4 additions & 3 deletions smarts/core/shader_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class BufferID(Enum):
STEPS_COMPLETED = "steps_completed"
ELAPSED_SIM_TIME = "elapsed_sim_time"

EVENTS_COLLISIONS = "events_collisions"
EVENTS_OFF_ROAD = "events_off_road"
EVENTS_OFF_ROUTE = "events_off_route"
EVENTS_ON_SHOULDER = "events_on_shoulder"
Expand Down Expand Up @@ -94,8 +95,8 @@ class BufferID(Enum):

ROAD_WAYPOINTS_POSITION = "road_waypoints_lanes_pos"
ROAD_WAYPOINTS_HEADING = "road_waypoints_lanes_heading"
ROAD_WAYPOINTS_LANE_ID = "road_waypoints_lane_id"
ROAD_WAYPOINTS_LANE_WIDTH = "road_waypoints_lane_width"
ROAD_WAYPOINTS_LANE_ID = "road_waypoints_lanes_lane_id"
ROAD_WAYPOINTS_LANE_WIDTH = "road_waypoints_lanes_width"
ROAD_WAYPOINTS_SPEED_LIMIT = "road_waypoints_lanes_speed_limit"
ROAD_WAYPOINTS_LANE_INDEX = "road_waypoints_lanes_lane_index"
ROAD_WAYPOINTS_LANE_OFFSET = "road_waypoints_lanes_lane_offset"
Expand All @@ -115,5 +116,5 @@ class BufferID(Enum):

SIGNALS_LIGHT_STATE = "signals_light_state"
SIGNALS_STOP_POINT = "signals_stop_point"
SIGNALS_CONTROLLED_LANES = "signals_controlled_lanes"
# SIGNALS_CONTROLLED_LANES = "signals_controlled_lanes"
SIGNALS_LAST_CHANGED = "signals_last_changed"
Loading

0 comments on commit 3e30a9c

Please sign in to comment.