feat: Implement Spatial Graphs as Signal Processing Layer
This commit is contained in:
469
spatial_graph.py
Normal file
469
spatial_graph.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Spatial Graph System for hierarchical location management and perception propagation.
|
||||
|
||||
Implements:
|
||||
- Hierarchical location graph (world → region → location → POI)
|
||||
- Perception bubble-up algorithm with LLM-powered filtering
|
||||
- Portal-based information propagation (vision/sound)
|
||||
- Spatial metadata for entity positioning within locations
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Optional, Dict, List, Tuple, Any
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PropagationType(Enum):
|
||||
"""Types of information that propagate through space."""
|
||||
|
||||
VISION = "vision"
|
||||
SOUND = "sound"
|
||||
ACTION = "action"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortalConnection:
|
||||
"""Represents a connection between two spatial nodes with propagation properties."""
|
||||
|
||||
target: str # Target location_id
|
||||
portal: Optional[str] = None # Portal type (door, window, peephole, etc.)
|
||||
portal_state_descriptor: Optional[str] = (
|
||||
None # Current state of portal (open, closed, tinted, etc.)
|
||||
)
|
||||
vision_prop: int = 5 # 0-10 scale: 0=no vision, 10=clear vision
|
||||
sound_prop: int = 5 # 0-10 scale: 0=no sound, 10=all sound propagates
|
||||
bidirectional: bool = True # Whether connection is symmetric
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerceptionInfo:
|
||||
"""Filtered/transformed perception for an entity."""
|
||||
|
||||
recipient_id: str
|
||||
original_action: str
|
||||
transformed_action: Optional[str] # None if blocked
|
||||
propagation_path: str # e.g., "direct", "through_wall", "muffled_from_distance"
|
||||
vision_clarity: int # 0-10
|
||||
sound_clarity: int # 0-10
|
||||
perceivable: bool # Whether entity perceives anything
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpatialNode:
|
||||
"""A node in the spatial hierarchy (world, region, location, POI)."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
node_type: str # "world", "region", "location", "poi"
|
||||
parent_id: Optional[str] = None
|
||||
children: List[str] = field(default_factory=list)
|
||||
connections: List[PortalConnection] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = asdict(self)
|
||||
data["connections"] = [c.to_dict() for c in self.connections]
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityPosition:
|
||||
"""Spatial metadata for an entity's position."""
|
||||
|
||||
entity_id: str
|
||||
location_id: str # Leaf-level location (POI or location)
|
||||
spatial_descriptor: str # "Leaning against the far left edge, near the kegs"
|
||||
|
||||
|
||||
class SpatialGraph:
|
||||
"""
|
||||
Manages hierarchical spatial structure and perception propagation.
|
||||
|
||||
Hierarchy: World → Region → Location → POI (Point of Interest)
|
||||
|
||||
Entities have:
|
||||
- location_id: Which leaf node they're in
|
||||
- spatial_descriptor: Precise position within that node
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.nodes: Dict[str, SpatialNode] = {}
|
||||
self.entity_positions: Dict[str, EntityPosition] = {}
|
||||
self.node_hierarchy: Dict[str, List[str]] = {} # location_id → path to root
|
||||
|
||||
def load_from_json(self, data: Dict[str, Any]) -> None:
|
||||
"""Load spatial graph from JSON structure (from demo.json)."""
|
||||
logger.info("Loading spatial graph from JSON")
|
||||
self._build_hierarchy(data["world"], None)
|
||||
logger.info(f"Spatial graph loaded: {len(self.nodes)} nodes")
|
||||
|
||||
def _build_hierarchy(
|
||||
self, node_data: Dict[str, Any], parent_id: Optional[str]
|
||||
) -> None:
|
||||
"""Recursively build spatial hierarchy from JSON."""
|
||||
node_id = node_data.get("id")
|
||||
node_type = self._determine_node_type(parent_id)
|
||||
|
||||
connections = []
|
||||
for conn_data in node_data.get("connections", []):
|
||||
conn = PortalConnection(
|
||||
target=conn_data["target"],
|
||||
portal=conn_data.get("portal"),
|
||||
portal_state_descriptor=conn_data.get("portal_state_descriptor"),
|
||||
vision_prop=conn_data.get("vision_prop", 5),
|
||||
sound_prop=conn_data.get("sound_prop", 5),
|
||||
bidirectional=conn_data.get("bidirectional", True),
|
||||
)
|
||||
connections.append(conn)
|
||||
|
||||
node = SpatialNode(
|
||||
id=node_id,
|
||||
name=node_data.get("name", node_id),
|
||||
description=node_data.get("description", ""),
|
||||
node_type=node_type,
|
||||
parent_id=parent_id,
|
||||
connections=connections,
|
||||
)
|
||||
|
||||
self.nodes[node_id] = node
|
||||
|
||||
# Process children
|
||||
children_key = self._get_children_key(node_type)
|
||||
for child_data in node_data.get(children_key, []):
|
||||
child_id = child_data.get("id")
|
||||
node.children.append(child_id)
|
||||
self._build_hierarchy(child_data, node_id)
|
||||
|
||||
def _determine_node_type(self, parent_id: Optional[str]) -> str:
|
||||
"""Determine node type based on parent."""
|
||||
if parent_id is None:
|
||||
return "world"
|
||||
parent = self.nodes.get(parent_id)
|
||||
if parent.node_type == "world":
|
||||
return "region"
|
||||
elif parent.node_type == "region":
|
||||
return "location"
|
||||
else:
|
||||
return "poi"
|
||||
|
||||
def _get_children_key(self, node_type: str) -> str:
|
||||
"""Get the key for children based on node type."""
|
||||
mapping = {
|
||||
"world": "regions",
|
||||
"region": "locations",
|
||||
"location": "pois",
|
||||
"poi": "items",
|
||||
}
|
||||
return mapping.get(node_type, "children")
|
||||
|
||||
def set_entity_position(
|
||||
self, entity_id: str, location_id: str, spatial_descriptor: str = ""
|
||||
) -> None:
|
||||
"""Set entity's spatial position."""
|
||||
if location_id not in self.nodes:
|
||||
logger.warning(f"Unknown location_id: {location_id}")
|
||||
return
|
||||
|
||||
node = self.nodes[location_id]
|
||||
if node.node_type not in ["location", "poi"]:
|
||||
logger.warning(
|
||||
f"Entity position must be at leaf level (location or poi), got {node.node_type}"
|
||||
)
|
||||
return
|
||||
|
||||
self.entity_positions[entity_id] = EntityPosition(
|
||||
entity_id=entity_id,
|
||||
location_id=location_id,
|
||||
spatial_descriptor=spatial_descriptor,
|
||||
)
|
||||
logger.info(
|
||||
f"Entity {entity_id} positioned at {location_id}: {spatial_descriptor}"
|
||||
)
|
||||
|
||||
def get_entity_position(self, entity_id: str) -> Optional[EntityPosition]:
|
||||
"""Get entity's current position."""
|
||||
return self.entity_positions.get(entity_id)
|
||||
|
||||
def get_entities_in_location(self, location_id: str) -> List[str]:
|
||||
"""Get all entities in a specific location."""
|
||||
return [
|
||||
entity_id
|
||||
for entity_id, pos in self.entity_positions.items()
|
||||
if pos.location_id == location_id
|
||||
]
|
||||
|
||||
def get_path_to_root(self, node_id: str) -> List[str]:
|
||||
"""Get path from node to root (world)."""
|
||||
path = []
|
||||
current = node_id
|
||||
while current:
|
||||
path.append(current)
|
||||
node = self.nodes.get(current)
|
||||
if not node:
|
||||
break
|
||||
current = node.parent_id
|
||||
return path
|
||||
|
||||
def get_immediate_children(self, node_id: str) -> List[str]:
|
||||
"""Get immediate children of a node."""
|
||||
node = self.nodes.get(node_id)
|
||||
return node.children if node else []
|
||||
|
||||
def get_all_descendants(self, node_id: str) -> List[str]:
|
||||
"""Get all descendants recursively."""
|
||||
descendants = []
|
||||
node = self.nodes.get(node_id)
|
||||
if not node:
|
||||
return descendants
|
||||
|
||||
for child_id in node.children:
|
||||
descendants.append(child_id)
|
||||
descendants.extend(self.get_all_descendants(child_id))
|
||||
|
||||
return descendants
|
||||
|
||||
def get_sibling_locations(self, node_id: str) -> List[str]:
|
||||
"""Get all sibling nodes (shared parent)."""
|
||||
node = self.nodes.get(node_id)
|
||||
if not node or not node.parent_id:
|
||||
return []
|
||||
|
||||
parent = self.nodes.get(node.parent_id)
|
||||
if not parent:
|
||||
return []
|
||||
|
||||
return [child for child in parent.children if child != node_id]
|
||||
|
||||
def get_connected_locations(
|
||||
self, node_id: str
|
||||
) -> List[Tuple[str, PortalConnection]]:
|
||||
"""
|
||||
Get all locations connected via portals.
|
||||
|
||||
Returns: List of (location_id, PortalConnection)
|
||||
"""
|
||||
node = self.nodes.get(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
connections = []
|
||||
for portal_conn in node.connections:
|
||||
target_node = self.nodes.get(portal_conn.target)
|
||||
if target_node:
|
||||
connections.append((portal_conn.target, portal_conn))
|
||||
|
||||
return connections
|
||||
|
||||
def broadcast_to_immediate_location(self, location_id: str) -> List[str]:
|
||||
"""
|
||||
Step 1: Immediate Leaf Check
|
||||
Get all entities in the same leaf location (no loss).
|
||||
"""
|
||||
return self.get_entities_in_location(location_id)
|
||||
|
||||
def broadcast_to_siblings_with_portals(
|
||||
self, location_id: str, action: str, llm_filter=None
|
||||
) -> Dict[str, PerceptionInfo]:
|
||||
"""
|
||||
Step 2: Sibling Check (Horizontal Propagation)
|
||||
Propagate perception to sibling locations through portals.
|
||||
|
||||
llm_filter: Callable that transforms action based on portal properties
|
||||
Returns: Dict mapping entity_id → PerceptionInfo
|
||||
"""
|
||||
perceptions: Dict[str, PerceptionInfo] = {}
|
||||
connected = self.get_connected_locations(location_id)
|
||||
|
||||
for target_id, portal_conn in connected:
|
||||
# Get entities in target location
|
||||
target_entities = self.get_entities_in_location(target_id)
|
||||
|
||||
# Transform action based on portal properties
|
||||
transformed_action = action
|
||||
if llm_filter:
|
||||
try:
|
||||
transformed_action = llm_filter(
|
||||
action=action,
|
||||
source_id=location_id,
|
||||
target_id=target_id,
|
||||
portal_conn=portal_conn,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM filter failed: {e}")
|
||||
transformed_action = action
|
||||
|
||||
for entity_id in target_entities:
|
||||
perceptions[entity_id] = PerceptionInfo(
|
||||
recipient_id=entity_id,
|
||||
original_action=action,
|
||||
transformed_action=transformed_action,
|
||||
propagation_path="portal",
|
||||
vision_clarity=portal_conn.vision_prop,
|
||||
sound_clarity=portal_conn.sound_prop,
|
||||
perceivable=transformed_action is not None,
|
||||
)
|
||||
|
||||
logger.info(f"Sibling broadcast: {len(perceptions)} entities reached")
|
||||
return perceptions
|
||||
|
||||
def broadcast_to_parent_children(
|
||||
self, location_id: str, action: str, escalation_check=None
|
||||
) -> Dict[str, PerceptionInfo]:
|
||||
"""
|
||||
Step 3: Parent Check (Vertical Propagation)
|
||||
Check if action warrants escalation to parent's other children.
|
||||
|
||||
escalation_check: Callable that determines if action warrants escalation
|
||||
Returns: Dict mapping entity_id → PerceptionInfo
|
||||
"""
|
||||
perceptions: Dict[str, PerceptionInfo] = {}
|
||||
|
||||
node = self.nodes.get(location_id)
|
||||
if not node or not node.parent_id:
|
||||
return perceptions
|
||||
|
||||
parent = self.nodes.get(node.parent_id)
|
||||
if not parent:
|
||||
return perceptions
|
||||
|
||||
# Check if action warrants escalation
|
||||
should_escalate = True
|
||||
if escalation_check:
|
||||
try:
|
||||
should_escalate = escalation_check(action, location_id, parent.id)
|
||||
except Exception as e:
|
||||
logger.error(f"Escalation check failed: {e}")
|
||||
|
||||
if not should_escalate:
|
||||
logger.info("Action did not warrant escalation to parent")
|
||||
return perceptions
|
||||
|
||||
# Broadcast to all siblings (other children of parent)
|
||||
for sibling_id in parent.children:
|
||||
if sibling_id == location_id:
|
||||
continue
|
||||
|
||||
sibling_entities = self.get_entities_in_location(sibling_id)
|
||||
for entity_id in sibling_entities:
|
||||
perceptions[entity_id] = PerceptionInfo(
|
||||
recipient_id=entity_id,
|
||||
original_action=action,
|
||||
transformed_action=f"[From {location_id}] {action}",
|
||||
propagation_path="escalated_from_sibling",
|
||||
vision_clarity=3,
|
||||
sound_clarity=4,
|
||||
perceivable=True,
|
||||
)
|
||||
|
||||
logger.info(f"Parent escalation: {len(perceptions)} entities reached")
|
||||
return perceptions
|
||||
|
||||
def bubble_up_broadcast(
|
||||
self,
|
||||
location_id: str,
|
||||
action: str,
|
||||
actor_id: str,
|
||||
llm_filter=None,
|
||||
escalation_check=None,
|
||||
) -> Dict[str, PerceptionInfo]:
|
||||
"""
|
||||
Full bubble-up algorithm:
|
||||
1. Immediate leaf check (all entities in same location)
|
||||
2. Sibling check (portal propagation)
|
||||
3. Parent check (escalation if warranted)
|
||||
|
||||
Returns: Dict mapping all perceiving entity_id → PerceptionInfo
|
||||
"""
|
||||
logger.info(f"Bubble-up broadcast from {location_id}: '{action}'")
|
||||
all_perceptions: Dict[str, PerceptionInfo] = {}
|
||||
|
||||
# Step 1: Immediate location (full perception)
|
||||
immediate = self.broadcast_to_immediate_location(location_id)
|
||||
for entity_id in immediate:
|
||||
if entity_id != actor_id: # Don't perceive own action
|
||||
all_perceptions[entity_id] = PerceptionInfo(
|
||||
recipient_id=entity_id,
|
||||
original_action=action,
|
||||
transformed_action=action,
|
||||
propagation_path="immediate",
|
||||
vision_clarity=10,
|
||||
sound_clarity=10,
|
||||
perceivable=True,
|
||||
)
|
||||
|
||||
# Step 2: Sibling propagation through portals
|
||||
sibling_perceptions = self.broadcast_to_siblings_with_portals(
|
||||
location_id, action, llm_filter
|
||||
)
|
||||
all_perceptions.update(sibling_perceptions)
|
||||
|
||||
# Step 3: Parent escalation
|
||||
parent_perceptions = self.broadcast_to_parent_children(
|
||||
location_id, action, escalation_check
|
||||
)
|
||||
all_perceptions.update(parent_perceptions)
|
||||
|
||||
logger.info(f"Total entities perceiving: {len(all_perceptions)}")
|
||||
return all_perceptions
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize spatial graph to dict."""
|
||||
return {
|
||||
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
|
||||
"entity_positions": {
|
||||
entity_id: asdict(pos)
|
||||
for entity_id, pos in self.entity_positions.items()
|
||||
},
|
||||
}
|
||||
|
||||
def save_to_json(self, filepath: str) -> None:
|
||||
"""Save spatial graph to JSON file."""
|
||||
try:
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(self.to_dict(), f, indent=2)
|
||||
logger.info(f"Spatial graph saved to {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save spatial graph: {e}")
|
||||
|
||||
@classmethod
|
||||
def load_from_json_file(cls, filepath: str) -> "SpatialGraph":
|
||||
"""Load spatial graph from JSON file."""
|
||||
try:
|
||||
with open(filepath, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
graph = cls()
|
||||
|
||||
# Rebuild nodes
|
||||
for node_id, node_data in data.get("nodes", {}).items():
|
||||
connections = [
|
||||
PortalConnection(**conn)
|
||||
for conn in node_data.get("connections", [])
|
||||
]
|
||||
node = SpatialNode(
|
||||
id=node_data["id"],
|
||||
name=node_data["name"],
|
||||
description=node_data["description"],
|
||||
node_type=node_data["node_type"],
|
||||
parent_id=node_data.get("parent_id"),
|
||||
children=node_data.get("children", []),
|
||||
connections=connections,
|
||||
)
|
||||
graph.nodes[node_id] = node
|
||||
|
||||
# Rebuild entity positions
|
||||
for entity_id, pos_data in data.get("entity_positions", {}).items():
|
||||
graph.entity_positions[entity_id] = EntityPosition(**pos_data)
|
||||
|
||||
logger.info(f"Spatial graph loaded from {filepath}")
|
||||
return graph
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load spatial graph: {e}")
|
||||
return cls()
|
||||
Reference in New Issue
Block a user