refactor: Refactor and split scenario_builder_tui
This commit is contained in:
File diff suppressed because it is too large
Load Diff
5
src/omnia/tools/scenario_tui/__init__.py
Normal file
5
src/omnia/tools/scenario_tui/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .entities import EntitiesScreen
|
||||
from .prompts import NodeTypePrompt, PathPrompt
|
||||
from .spatial_graph import SpatialGraphScreen
|
||||
|
||||
__all__ = ["EntitiesScreen", "NodeTypePrompt", "PathPrompt", "SpatialGraphScreen"]
|
||||
321
src/omnia/tools/scenario_tui/entities.py
Normal file
321
src/omnia/tools/scenario_tui/entities.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
TextArea,
|
||||
)
|
||||
|
||||
from omnia.tools.scenario_builder import DEFAULT_ENTITY
|
||||
|
||||
|
||||
class EntityItem(ListItem):
|
||||
def __init__(self, index: int, entity: dict[str, Any]) -> None:
|
||||
super().__init__(Label(self._label(entity)))
|
||||
self.index = index
|
||||
|
||||
@staticmethod
|
||||
def _label(entity: dict[str, Any]) -> str:
|
||||
name = str(entity.get("name") or "Unnamed").strip()
|
||||
entity_id = str(entity.get("id") or "no-id").strip()
|
||||
return f"{name} ({entity_id})"
|
||||
|
||||
|
||||
class EntitiesScreen(Screen[Optional[dict[str, Any]]]):
|
||||
CSS = """
|
||||
Screen {
|
||||
layout: vertical;
|
||||
}
|
||||
#entities-main {
|
||||
height: 1fr;
|
||||
}
|
||||
#entity-left {
|
||||
width: 35%;
|
||||
min-width: 24;
|
||||
border: round $primary;
|
||||
padding: 1;
|
||||
}
|
||||
#entity-right {
|
||||
width: 1fr;
|
||||
border: round $primary;
|
||||
padding: 1;
|
||||
}
|
||||
#entity_list {
|
||||
height: 1fr;
|
||||
}
|
||||
#entity-form {
|
||||
height: 1fr;
|
||||
}
|
||||
#entity_stats,
|
||||
#entity_voice,
|
||||
#entity_memories {
|
||||
height: 4;
|
||||
}
|
||||
#entity-traits-label,
|
||||
#entity-stats-label,
|
||||
#entity-voice-label,
|
||||
#entity-memories-label {
|
||||
margin-top: 1;
|
||||
}
|
||||
#entity_status {
|
||||
height: 1;
|
||||
padding: 0 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, scenario: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.scenario = scenario
|
||||
self.entities = self.scenario.setdefault("entities", [])
|
||||
if not isinstance(self.entities, list):
|
||||
self.entities = []
|
||||
self.scenario["entities"] = self.entities
|
||||
self.selected_index: Optional[int] = None
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with Horizontal(id="entities-main"):
|
||||
with Vertical(id="entity-left"):
|
||||
yield Label("Entities", id="entity-title")
|
||||
yield ListView(id="entity_list")
|
||||
with Horizontal(id="entity-actions"):
|
||||
yield Button("New", id="entity_new")
|
||||
yield Button("Delete", id="entity_delete")
|
||||
with Vertical(id="entity-right"):
|
||||
yield Label("Entity Details", id="entity-current")
|
||||
with VerticalScroll(id="entity-form"):
|
||||
yield Label("Entity ID", id="entity-id-label")
|
||||
yield Input(id="entity_id")
|
||||
yield Label("Name", id="entity-name-label")
|
||||
yield Input(id="entity_name")
|
||||
yield Label("Traits (comma-separated)", id="entity-traits-label")
|
||||
yield Input(id="entity_traits")
|
||||
yield Label("Stats (key=value per line)", id="entity-stats-label")
|
||||
yield TextArea(id="entity_stats")
|
||||
yield Label("Voice sample", id="entity-voice-label")
|
||||
yield TextArea(id="entity_voice")
|
||||
yield Label("Current mood", id="entity-mood-label")
|
||||
yield Input(id="entity_mood")
|
||||
yield Label("Location ID", id="entity-location-label")
|
||||
yield Input(id="entity_location")
|
||||
yield Label("Spatial descriptor", id="entity-spatial-label")
|
||||
yield Input(id="entity_spatial")
|
||||
yield Label(
|
||||
"Memories (one per line; JSON allowed)",
|
||||
id="entity-memories-label",
|
||||
)
|
||||
yield TextArea(id="entity_memories")
|
||||
with Horizontal(id="entity-actions-right"):
|
||||
yield Button("Save", id="entity_save", variant="primary")
|
||||
yield Button("Apply", id="entity_apply")
|
||||
yield Button("Cancel", id="entity_cancel")
|
||||
yield Label("", id="entity_status")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self._reload_entity_list(select_index=0 if self.entities else None)
|
||||
if self.entities:
|
||||
self._load_entity_into_form(self.entities[0])
|
||||
self.selected_index = 0
|
||||
else:
|
||||
self._load_entity_into_form(self._default_entity())
|
||||
self.selected_index = None
|
||||
self._set_status("Editing entities")
|
||||
|
||||
def _set_status(self, message: str) -> None:
|
||||
self.query_one("#entity_status", Label).update(message)
|
||||
|
||||
def _default_entity(self) -> dict[str, Any]:
|
||||
return json.loads(json.dumps(DEFAULT_ENTITY))
|
||||
|
||||
def _load_entity_into_form(self, entity: dict[str, Any]) -> None:
|
||||
self.query_one("#entity_id", Input).value = str(entity.get("id", ""))
|
||||
self.query_one("#entity_name", Input).value = str(entity.get("name", ""))
|
||||
traits = entity.get("traits", [])
|
||||
if isinstance(traits, list):
|
||||
traits_text = ", ".join(str(t).strip() for t in traits if str(t).strip())
|
||||
else:
|
||||
traits_text = str(traits)
|
||||
self.query_one("#entity_traits", Input).value = traits_text
|
||||
|
||||
stats = entity.get("stats", {})
|
||||
stats_lines = []
|
||||
if isinstance(stats, dict):
|
||||
for key, value in stats.items():
|
||||
stats_lines.append(f"{key}={value}")
|
||||
self.query_one("#entity_stats", TextArea).text = "\n".join(stats_lines)
|
||||
|
||||
self.query_one("#entity_voice", TextArea).text = str(
|
||||
entity.get("voice_sample", "")
|
||||
)
|
||||
self.query_one("#entity_mood", Input).value = str(
|
||||
entity.get("current_mood", "")
|
||||
)
|
||||
|
||||
metadata = (
|
||||
entity.get("metadata", {})
|
||||
if isinstance(entity.get("metadata"), dict)
|
||||
else {}
|
||||
)
|
||||
self.query_one("#entity_location", Input).value = str(
|
||||
metadata.get("location", "")
|
||||
)
|
||||
self.query_one("#entity_spatial", Input).value = str(
|
||||
metadata.get("spatial_descriptor", "")
|
||||
)
|
||||
|
||||
memories = entity.get("memories", [])
|
||||
memory_lines = []
|
||||
if isinstance(memories, list):
|
||||
for memory in memories:
|
||||
if isinstance(memory, (dict, list)):
|
||||
memory_lines.append(json.dumps(memory))
|
||||
else:
|
||||
memory_lines.append(str(memory))
|
||||
self.query_one("#entity_memories", TextArea).text = "\n".join(memory_lines)
|
||||
|
||||
def _build_entity_from_form(self) -> Optional[dict[str, Any]]:
|
||||
entity_id = self.query_one("#entity_id", Input).value.strip()
|
||||
name = self.query_one("#entity_name", Input).value.strip()
|
||||
if not entity_id or not name:
|
||||
self._set_status("Entity requires non-empty ID and name")
|
||||
return None
|
||||
|
||||
entity = self._default_entity()
|
||||
entity["id"] = entity_id
|
||||
entity["name"] = name
|
||||
|
||||
traits_value = self.query_one("#entity_traits", Input).value
|
||||
traits = [t.strip() for t in traits_value.split(",") if t.strip()]
|
||||
entity["traits"] = traits
|
||||
|
||||
stats_text = self.query_one("#entity_stats", TextArea).text
|
||||
stats: dict[str, int] = {}
|
||||
for line in stats_text.splitlines():
|
||||
raw = line.strip()
|
||||
if not raw:
|
||||
continue
|
||||
if "=" not in raw:
|
||||
self._set_status(f"Invalid stats line (missing '='): {raw}")
|
||||
return None
|
||||
key, value = raw.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key or not value:
|
||||
self._set_status(f"Invalid stats line: {raw}")
|
||||
return None
|
||||
try:
|
||||
stats[key] = int(value)
|
||||
except ValueError:
|
||||
self._set_status(f"Stat value must be an integer: {raw}")
|
||||
return None
|
||||
entity["stats"] = stats
|
||||
|
||||
entity["voice_sample"] = self.query_one("#entity_voice", TextArea).text.strip()
|
||||
mood = self.query_one("#entity_mood", Input).value.strip()
|
||||
entity["current_mood"] = mood if mood else "Neutral"
|
||||
|
||||
entity["metadata"]["location"] = self.query_one(
|
||||
"#entity_location", Input
|
||||
).value.strip()
|
||||
entity["metadata"]["spatial_descriptor"] = self.query_one(
|
||||
"#entity_spatial", Input
|
||||
).value.strip()
|
||||
|
||||
memories_text = self.query_one("#entity_memories", TextArea).text
|
||||
memories: list[Any] = []
|
||||
for line in memories_text.splitlines():
|
||||
raw = line.strip()
|
||||
if not raw:
|
||||
continue
|
||||
if raw.startswith(("{", "[")):
|
||||
try:
|
||||
memories.append(json.loads(raw))
|
||||
except json.JSONDecodeError:
|
||||
self._set_status(f"Invalid memory JSON: {raw}")
|
||||
return None
|
||||
else:
|
||||
memories.append(raw)
|
||||
entity["memories"] = memories
|
||||
|
||||
return entity
|
||||
|
||||
def _reload_entity_list(self, select_index: Optional[int] = None) -> None:
|
||||
list_view = self.query_one("#entity_list", ListView)
|
||||
list_view.clear()
|
||||
for index, entity in enumerate(self.entities):
|
||||
list_view.append(EntityItem(index, entity))
|
||||
if select_index is not None and self.entities:
|
||||
clamped = max(0, min(select_index, len(self.entities) - 1))
|
||||
list_view.index = clamped
|
||||
self.selected_index = clamped
|
||||
|
||||
def on_list_view_selected(self, event: ListView.Selected) -> None:
|
||||
if isinstance(event.item, EntityItem):
|
||||
self.selected_index = event.item.index
|
||||
self._load_entity_into_form(self.entities[event.item.index])
|
||||
|
||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
if button_id == "entity_new":
|
||||
self.selected_index = None
|
||||
self._load_entity_into_form(self._default_entity())
|
||||
self._set_status("Creating new entity")
|
||||
return
|
||||
|
||||
if button_id == "entity_save":
|
||||
entity = self._build_entity_from_form()
|
||||
if entity is None:
|
||||
return
|
||||
|
||||
existing_ids = {
|
||||
e.get("id")
|
||||
for i, e in enumerate(self.entities)
|
||||
if i != self.selected_index
|
||||
}
|
||||
if entity["id"] in existing_ids:
|
||||
self._set_status(f"Entity ID '{entity['id']}' already exists")
|
||||
return
|
||||
|
||||
if self.selected_index is None:
|
||||
self.entities.append(entity)
|
||||
self.selected_index = len(self.entities) - 1
|
||||
else:
|
||||
self.entities[self.selected_index] = entity
|
||||
|
||||
self._reload_entity_list(select_index=self.selected_index)
|
||||
self._set_status("Entity saved")
|
||||
return
|
||||
|
||||
if button_id == "entity_delete":
|
||||
if self.selected_index is None:
|
||||
self._set_status("Select an entity to delete")
|
||||
return
|
||||
removed = self.entities.pop(self.selected_index)
|
||||
self._reload_entity_list(select_index=self.selected_index)
|
||||
if self.entities:
|
||||
self._load_entity_into_form(
|
||||
self.entities[min(self.selected_index, len(self.entities) - 1)]
|
||||
)
|
||||
else:
|
||||
self._load_entity_into_form(self._default_entity())
|
||||
self.selected_index = None
|
||||
self._set_status(f"Deleted {removed.get('name')}")
|
||||
return
|
||||
|
||||
if button_id == "entity_apply":
|
||||
self.scenario["entities"] = self.entities
|
||||
self.dismiss(self.scenario)
|
||||
return
|
||||
|
||||
if button_id == "entity_cancel":
|
||||
self.dismiss(None)
|
||||
71
src/omnia/tools/scenario_tui/prompts.py
Normal file
71
src/omnia/tools/scenario_tui/prompts.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Grid, Horizontal
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Input, Label
|
||||
|
||||
|
||||
class PathPrompt(ModalScreen[Optional[str]]):
|
||||
def __init__(self, title: str, default: str = "") -> None:
|
||||
super().__init__()
|
||||
self.title = title
|
||||
self.default = default
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Label(self.title, id="prompt-title")
|
||||
yield Input(value=self.default, id="prompt-input")
|
||||
with Horizontal(id="prompt-buttons"):
|
||||
yield Button("OK", id="ok", variant="primary")
|
||||
yield Button("Cancel", id="cancel")
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.query_one(Input).focus()
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if event.button.id == "ok":
|
||||
value = self.query_one(Input).value.strip()
|
||||
self.dismiss(value or None)
|
||||
else:
|
||||
self.dismiss(None)
|
||||
|
||||
|
||||
class NodeTypePrompt(ModalScreen[Optional[str]]):
|
||||
CSS = """
|
||||
NodeTypePrompt {
|
||||
align: center middle;
|
||||
}
|
||||
#dialog {
|
||||
layout: grid;
|
||||
grid-size: 2;
|
||||
grid-gutter: 1 2;
|
||||
grid-rows: 1fr 3;
|
||||
padding: 0 1;
|
||||
width: 60;
|
||||
height: 11;
|
||||
border: thick $background 80%;
|
||||
background: $surface;
|
||||
}
|
||||
#node-type-title {
|
||||
column-span: 2;
|
||||
content-align: center middle;
|
||||
}
|
||||
"""
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
with Grid(id="dialog"):
|
||||
yield Label("Select node type", id="node-type-title")
|
||||
yield Button("🗺️ Region", id="node_region", variant="primary")
|
||||
yield Button("📍 Location", id="node_location")
|
||||
yield Button("⭐ POI", id="node_poi")
|
||||
yield Button("Cancel", id="node_cancel")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if event.button.id == "node_region":
|
||||
self.dismiss("region")
|
||||
elif event.button.id == "node_location":
|
||||
self.dismiss("location")
|
||||
elif event.button.id == "node_poi":
|
||||
self.dismiss("poi")
|
||||
else:
|
||||
self.dismiss(None)
|
||||
738
src/omnia/tools/scenario_tui/spatial_graph.py
Normal file
738
src/omnia/tools/scenario_tui/spatial_graph.py
Normal file
@@ -0,0 +1,738 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal, Vertical
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Label,
|
||||
Select,
|
||||
TextArea,
|
||||
Tree,
|
||||
)
|
||||
|
||||
from omnia.tools.scenario_builder import DEFAULT_SCENARIO
|
||||
from omnia.tools.scenario_tui.prompts import NodeTypePrompt
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpatialNodeRef:
|
||||
node_type: str
|
||||
obj: dict[str, Any]
|
||||
|
||||
|
||||
class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
|
||||
CSS = """
|
||||
Screen {
|
||||
layout: vertical;
|
||||
}
|
||||
#spatial-main {
|
||||
height: 1fr;
|
||||
}
|
||||
#spatial-left {
|
||||
width: 35%;
|
||||
min-width: 24;
|
||||
border: round $primary;
|
||||
padding: 1;
|
||||
}
|
||||
#spatial-right {
|
||||
width: 1fr;
|
||||
border: round $primary;
|
||||
padding: 1;
|
||||
}
|
||||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
#spatial-tree {
|
||||
height: 1fr;
|
||||
}
|
||||
#node-description {
|
||||
height: 5;
|
||||
}
|
||||
#spatial-status {
|
||||
height: 1;
|
||||
padding: 0 1;
|
||||
}
|
||||
#spatial-actions {
|
||||
height: auto;
|
||||
width: auto;
|
||||
align-horizontal: left;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, scenario: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.scenario = scenario
|
||||
self.selected_node = None
|
||||
self.node_map: dict[int, Any] = {}
|
||||
self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {}
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with Horizontal(id="spatial-main"):
|
||||
with Vertical(id="spatial-left"):
|
||||
yield Label("Spatial Graph", id="spatial-title")
|
||||
yield Tree("🌍 World", id="spatial-tree")
|
||||
with Horizontal(id="spatial-tree-actions"):
|
||||
yield Button("Add", id="spatial_add", variant="primary")
|
||||
yield Button("Delete", id="spatial_delete")
|
||||
with Horizontal(id="spatial-nesting-actions"):
|
||||
yield Button("Promote", id="spatial_promote")
|
||||
yield Button("Demote", id="spatial_demote")
|
||||
with Vertical(id="spatial-right"):
|
||||
yield Label("Node Details", id="node-title")
|
||||
with Vertical(id="node-details"):
|
||||
yield Label("Type", id="node-type-label")
|
||||
yield Label("", id="node-type-value")
|
||||
yield Label("ID", id="node-id-label")
|
||||
yield Input(id="node_id")
|
||||
yield Label("Name", id="node-name-label")
|
||||
yield Input(id="node_name")
|
||||
yield Label("Description", id="node-desc-label")
|
||||
yield TextArea(id="node-description")
|
||||
yield Label("Parent", id="node-parent-label")
|
||||
yield Select([], prompt="Select parent", id="node_parent")
|
||||
with Horizontal(id="node-actions"):
|
||||
yield Button("Generate", id="node_generate")
|
||||
yield Button("Update", id="node_update", variant="primary")
|
||||
yield Button("Reset", id="node_reset")
|
||||
with Horizontal(id="spatial-actions"):
|
||||
yield Button("Save", id="spatial_save", variant="primary")
|
||||
yield Button("Cancel", id="spatial_cancel")
|
||||
yield Label("", id="spatial-status")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self._ensure_spatial_graph()
|
||||
self._build_tree()
|
||||
self._clear_form()
|
||||
self._set_status("Select a node to edit")
|
||||
|
||||
def _set_status(self, message: str) -> None:
|
||||
self.query_one("#spatial-status", Label).update(message)
|
||||
|
||||
def _toast(self, message: str, severity: str = "warning") -> None:
|
||||
try:
|
||||
self.app.notify(message, severity=severity)
|
||||
except Exception:
|
||||
self._set_status(message)
|
||||
|
||||
def _ensure_spatial_graph(self) -> None:
|
||||
spatial_graph = self.scenario.get("spatial_graph")
|
||||
if spatial_graph is None:
|
||||
spatial_graph = json.loads(json.dumps(DEFAULT_SCENARIO["spatial_graph"]))
|
||||
self.scenario["spatial_graph"] = spatial_graph
|
||||
world = spatial_graph.setdefault("world", {})
|
||||
world.setdefault("regions", [])
|
||||
world.setdefault("name", "World")
|
||||
world.setdefault("id", "world")
|
||||
|
||||
def _world(self) -> dict[str, Any]:
|
||||
return self.scenario["spatial_graph"]["world"]
|
||||
|
||||
def _build_tree(
|
||||
self,
|
||||
select_obj: Optional[dict[str, Any]] = None,
|
||||
expanded: Optional[set[int]] = None,
|
||||
) -> None:
|
||||
tree = self.query_one("#spatial-tree", Tree)
|
||||
world = self._world()
|
||||
tree.reset(f"🌍 {world.get('name', 'World')}")
|
||||
tree.root.data = SpatialNodeRef("world", world)
|
||||
tree.root.expand()
|
||||
self.selected_node = None
|
||||
self.node_map = {}
|
||||
expanded = expanded or set()
|
||||
|
||||
def add_region(region: dict[str, Any]) -> None:
|
||||
node = tree.root.add(
|
||||
self._label_for("region", region), data=SpatialNodeRef("region", region)
|
||||
)
|
||||
self.node_map[id(region)] = node
|
||||
if id(region) in expanded:
|
||||
node.expand()
|
||||
for location in region.get("locations", []):
|
||||
add_location(node, location)
|
||||
|
||||
def add_location(parent_node: Any, location: dict[str, Any]) -> None:
|
||||
loc_node = parent_node.add(
|
||||
self._label_for("location", location),
|
||||
data=SpatialNodeRef("location", location),
|
||||
)
|
||||
self.node_map[id(location)] = loc_node
|
||||
if id(location) in expanded:
|
||||
loc_node.expand()
|
||||
for nested in location.get("locations", []):
|
||||
add_location(loc_node, nested)
|
||||
for poi in location.get("pois", []):
|
||||
poi_node = loc_node.add(
|
||||
self._label_for("poi", poi), data=SpatialNodeRef("poi", poi)
|
||||
)
|
||||
self.node_map[id(poi)] = poi_node
|
||||
if id(poi) in expanded:
|
||||
poi_node.expand()
|
||||
|
||||
for region in world.get("regions", []):
|
||||
add_region(region)
|
||||
|
||||
if select_obj and id(select_obj) in self.node_map:
|
||||
self.selected_node = self.node_map[id(select_obj)]
|
||||
self._expand_ancestors(self.selected_node)
|
||||
self._load_form_from_node(self.selected_node)
|
||||
else:
|
||||
self._clear_form()
|
||||
|
||||
def _label_for(self, node_type: str, obj: dict[str, Any]) -> str:
|
||||
name = str(obj.get("name") or "Unnamed").strip()
|
||||
node_id = str(obj.get("id") or "no-id").strip()
|
||||
icon = {"region": "🗺️", "location": "📍", "poi": "⭐"}.get(node_type, "•")
|
||||
return f"{icon} {name} ({node_id})"
|
||||
|
||||
def _capture_expanded_nodes(self) -> set[int]:
|
||||
tree = self.query_one("#spatial-tree", Tree)
|
||||
expanded: set[int] = set()
|
||||
|
||||
def walk(node: Any) -> None:
|
||||
if node.is_expanded and isinstance(node.data, SpatialNodeRef):
|
||||
expanded.add(id(node.data.obj))
|
||||
for child in node.children:
|
||||
walk(child)
|
||||
|
||||
walk(tree.root)
|
||||
return expanded
|
||||
|
||||
def _expand_ancestors(self, node: Any) -> None:
|
||||
current = node
|
||||
while current is not None:
|
||||
current.expand()
|
||||
current = current.parent
|
||||
|
||||
def _clear_form(self) -> None:
|
||||
self.query_one("#node-type-value", Label).update("")
|
||||
self.query_one("#node_id", Input).value = ""
|
||||
self.query_one("#node_name", Input).value = ""
|
||||
self.query_one("#node-description", TextArea).text = ""
|
||||
self.query_one("#node_parent", Select).set_options([])
|
||||
self.query_one("#node_parent", Select).value = Select.NULL
|
||||
self.query_one("#node-details", Vertical).add_class("hidden")
|
||||
self.query_one("#node_generate", Button).disabled = True
|
||||
self.query_one("#node_update", Button).disabled = True
|
||||
self.query_one("#node_reset", Button).disabled = True
|
||||
self.query_one("#spatial_delete", Button).disabled = True
|
||||
self.query_one("#spatial_promote", Button).disabled = True
|
||||
self.query_one("#spatial_demote", Button).disabled = True
|
||||
|
||||
def _load_form_from_node(self, node: Any) -> None:
|
||||
ref = node.data
|
||||
if not isinstance(ref, SpatialNodeRef) or ref.node_type == "world":
|
||||
self._clear_form()
|
||||
return
|
||||
self.query_one("#node-details", Vertical).remove_class("hidden")
|
||||
self.query_one("#node-type-value", Label).update(ref.node_type.title())
|
||||
self.query_one("#node_id", Input).value = str(ref.obj.get("id", ""))
|
||||
self.query_one("#node_name", Input).value = str(ref.obj.get("name", ""))
|
||||
self.query_one("#node-description", TextArea).text = str(
|
||||
ref.obj.get("description", "")
|
||||
)
|
||||
self.query_one("#node_generate", Button).disabled = False
|
||||
self.query_one("#node_update", Button).disabled = False
|
||||
self.query_one("#node_reset", Button).disabled = False
|
||||
self.query_one("#spatial_delete", Button).disabled = False
|
||||
is_location = ref.node_type == "location"
|
||||
can_promote, can_demote = (
|
||||
self._location_move_state(node) if is_location else (False, False)
|
||||
)
|
||||
self.query_one("#spatial_promote", Button).disabled = not can_promote
|
||||
self.query_one("#spatial_demote", Button).disabled = not can_demote
|
||||
parent_select = self.query_one("#node_parent", Select)
|
||||
if is_location:
|
||||
options, selected_value, option_map = self._build_parent_options(ref.obj)
|
||||
self.parent_option_map = option_map
|
||||
parent_select.set_options(options)
|
||||
parent_select.value = (
|
||||
selected_value if selected_value is not None else Select.NULL
|
||||
)
|
||||
parent_select.disabled = False
|
||||
else:
|
||||
self.parent_option_map = {}
|
||||
parent_select.set_options([])
|
||||
parent_select.value = Select.NULL
|
||||
parent_select.disabled = True
|
||||
|
||||
def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
|
||||
self.selected_node = event.node
|
||||
self._load_form_from_node(event.node)
|
||||
|
||||
def _open_node_type_prompt(self) -> None:
|
||||
self.app.push_screen(NodeTypePrompt(), callback=self._handle_node_type)
|
||||
|
||||
def _handle_node_type(self, result: Optional[str]) -> None:
|
||||
if result:
|
||||
self._add_node(result)
|
||||
|
||||
def _selected_ref(self) -> Optional[SpatialNodeRef]:
|
||||
if self.selected_node is None:
|
||||
return None
|
||||
ref = self.selected_node.data
|
||||
if isinstance(ref, SpatialNodeRef):
|
||||
return ref
|
||||
return None
|
||||
|
||||
def _selected_parent_node(self) -> Optional[Any]:
|
||||
if self.selected_node is None:
|
||||
return None
|
||||
return self.selected_node.parent
|
||||
|
||||
def _parent_locations_list(
|
||||
self, parent_node: Any
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
|
||||
return None
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type not in {"region", "location"}:
|
||||
return None
|
||||
return parent_ref.obj.setdefault("locations", [])
|
||||
|
||||
def _location_move_state(self, node: Any) -> tuple[bool, bool]:
|
||||
if node is None or not isinstance(node.data, SpatialNodeRef):
|
||||
return (False, False)
|
||||
ref = node.data
|
||||
if ref.node_type != "location":
|
||||
return (False, False)
|
||||
parent_node = node.parent
|
||||
parent_locations = self._parent_locations_list(parent_node)
|
||||
if parent_locations is None:
|
||||
return (False, False)
|
||||
can_promote = False
|
||||
if parent_node is not None and isinstance(parent_node.data, SpatialNodeRef):
|
||||
parent_ref = parent_node.data
|
||||
can_promote = parent_ref.node_type == "location"
|
||||
can_demote = False
|
||||
if ref.obj in parent_locations:
|
||||
index = parent_locations.index(ref.obj)
|
||||
can_demote = index < len(parent_locations) - 1
|
||||
return (can_promote, can_demote)
|
||||
|
||||
def _parent_pois_list(self, parent_node: Any) -> Optional[list[dict[str, Any]]]:
|
||||
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
|
||||
return None
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type != "location":
|
||||
return None
|
||||
return parent_ref.obj.setdefault("pois", [])
|
||||
|
||||
def _collect_locations(self) -> list[dict[str, Any]]:
|
||||
locations: list[dict[str, Any]] = []
|
||||
|
||||
def walk(location_list: list[dict[str, Any]]) -> None:
|
||||
for location in location_list:
|
||||
locations.append(location)
|
||||
walk(location.get("locations", []))
|
||||
|
||||
for region in self._world().get("regions", []):
|
||||
walk(region.get("locations", []))
|
||||
return locations
|
||||
|
||||
def _location_descendants(self, location: dict[str, Any]) -> set[int]:
|
||||
descendants: set[int] = set()
|
||||
|
||||
def walk(children: list[dict[str, Any]]) -> None:
|
||||
for child in children:
|
||||
descendants.add(id(child))
|
||||
walk(child.get("locations", []))
|
||||
|
||||
walk(location.get("locations", []))
|
||||
return descendants
|
||||
|
||||
def _find_location_parent(
|
||||
self, target: dict[str, Any]
|
||||
) -> Optional[tuple[str, dict[str, Any]]]:
|
||||
world = self._world()
|
||||
for region in world.get("regions", []):
|
||||
if target in region.get("locations", []):
|
||||
return ("region", region)
|
||||
for location in region.get("locations", []):
|
||||
found = self._find_location_parent_in_location(location, target)
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
def _find_location_parent_in_location(
|
||||
self, location: dict[str, Any], target: dict[str, Any]
|
||||
) -> Optional[tuple[str, dict[str, Any]]]:
|
||||
if target in location.get("locations", []):
|
||||
return ("location", location)
|
||||
for nested in location.get("locations", []):
|
||||
found = self._find_location_parent_in_location(nested, target)
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
def _build_parent_options(
|
||||
self, current: dict[str, Any]
|
||||
) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]:
|
||||
options: list[tuple[str, str]] = []
|
||||
option_map: dict[str, tuple[str, dict[str, Any]]] = {}
|
||||
forbidden = {id(current)} | self._location_descendants(current)
|
||||
selected_value: Optional[str] = None
|
||||
|
||||
for region in self._world().get("regions", []):
|
||||
region_id = str(region.get("id", "")).strip()
|
||||
region_name = str(region.get("name", "")).strip()
|
||||
value = f"region@{id(region)}"
|
||||
label = (
|
||||
f"🗺️ {region_name} ({region_id})"
|
||||
if region_name
|
||||
else f"🗺️ {region_id or 'region'}"
|
||||
)
|
||||
options.append((label, value))
|
||||
option_map[value] = ("region", region)
|
||||
|
||||
for location in self._collect_locations():
|
||||
if id(location) in forbidden:
|
||||
continue
|
||||
location_id = str(location.get("id", "")).strip()
|
||||
location_name = str(location.get("name", "")).strip()
|
||||
value = f"location@{id(location)}"
|
||||
label = (
|
||||
f"📍 {location_name} ({location_id})"
|
||||
if location_name
|
||||
else f"📍 {location_id or 'location'}"
|
||||
)
|
||||
options.append((label, value))
|
||||
option_map[value] = ("location", location)
|
||||
|
||||
parent_info = self._find_location_parent(current)
|
||||
if parent_info:
|
||||
parent_type, parent_obj = parent_info
|
||||
for value, (entry_type, entry_obj) in option_map.items():
|
||||
if entry_type == parent_type and entry_obj is parent_obj:
|
||||
selected_value = value
|
||||
break
|
||||
|
||||
return options, selected_value, option_map
|
||||
|
||||
def _update_selected_node(self) -> None:
|
||||
ref = self._selected_ref()
|
||||
if ref is None or ref.node_type == "world":
|
||||
self._set_status("Select a node to update")
|
||||
return
|
||||
node_id = self.query_one("#node_id", Input).value.strip()
|
||||
name = self.query_one("#node_name", Input).value.strip()
|
||||
description = self.query_one("#node-description", TextArea).text.strip()
|
||||
if not node_id or not name:
|
||||
self._set_status("ID and name are required")
|
||||
return
|
||||
ref.obj["id"] = node_id
|
||||
ref.obj["name"] = name
|
||||
ref.obj["description"] = description
|
||||
if ref.node_type == "region":
|
||||
ref.obj.setdefault("locations", [])
|
||||
elif ref.node_type == "location":
|
||||
ref.obj.setdefault("locations", [])
|
||||
ref.obj.setdefault("pois", [])
|
||||
elif ref.node_type == "poi":
|
||||
ref.obj.setdefault("connections", [])
|
||||
if ref.node_type == "location":
|
||||
parent_select = self.query_one("#node_parent", Select)
|
||||
if parent_select.value is not Select.NULL:
|
||||
selected_value = str(parent_select.value)
|
||||
parent_info = self.parent_option_map.get(selected_value)
|
||||
if parent_info:
|
||||
parent_type, parent_obj = parent_info
|
||||
current_parent = self._find_location_parent(ref.obj)
|
||||
if current_parent is None or current_parent[1] is not parent_obj:
|
||||
if current_parent:
|
||||
current_type, current_obj = current_parent
|
||||
if current_type == "region":
|
||||
current_obj.get("locations", []).remove(ref.obj)
|
||||
elif current_type == "location":
|
||||
current_obj.get("locations", []).remove(ref.obj)
|
||||
parent_obj.setdefault("locations", []).append(ref.obj)
|
||||
expanded = self._capture_expanded_nodes()
|
||||
self._build_tree(select_obj=ref.obj, expanded=expanded)
|
||||
self._set_status("Node updated")
|
||||
return
|
||||
if self.selected_node:
|
||||
self.selected_node.label = self._label_for(ref.node_type, ref.obj)
|
||||
self._set_status("Node updated")
|
||||
|
||||
def _node_path(self, node: Any) -> str:
|
||||
parts = []
|
||||
current = node
|
||||
while current is not None and isinstance(current.data, SpatialNodeRef):
|
||||
ref = current.data
|
||||
if ref.node_type == "world":
|
||||
break
|
||||
label = str(ref.obj.get("name") or ref.obj.get("id") or ref.node_type)
|
||||
parts.append(f"{ref.node_type}: {label}")
|
||||
current = current.parent
|
||||
return " > ".join(reversed(parts)) if parts else "World"
|
||||
|
||||
def _collect_ids(self, node_type: str) -> set[str]:
|
||||
world = self._world()
|
||||
ids: set[str] = set()
|
||||
|
||||
if node_type == "region":
|
||||
for region in world.get("regions", []):
|
||||
if region_id := str(region.get("id", "")).strip():
|
||||
ids.add(region_id)
|
||||
return ids
|
||||
|
||||
def walk_locations(locations: list[dict[str, Any]]) -> None:
|
||||
for location in locations:
|
||||
if node_type == "location":
|
||||
if loc_id := str(location.get("id", "")).strip():
|
||||
ids.add(loc_id)
|
||||
if node_type == "poi":
|
||||
for poi in location.get("pois", []):
|
||||
if poi_id := str(poi.get("id", "")).strip():
|
||||
ids.add(poi_id)
|
||||
walk_locations(location.get("locations", []))
|
||||
|
||||
for region in world.get("regions", []):
|
||||
walk_locations(region.get("locations", []))
|
||||
|
||||
return ids
|
||||
|
||||
def _generate_node_fields(self) -> None:
|
||||
ref = self._selected_ref()
|
||||
if ref is None or ref.node_type == "world":
|
||||
self._toast("Select a node to generate")
|
||||
return
|
||||
|
||||
scenario_meta = self.scenario.get("scenario", {})
|
||||
node_type = ref.node_type
|
||||
system_prompt = """
|
||||
You are helping build a game scenario.
|
||||
Return ONLY a single-line JSON object with keys:
|
||||
id, name, description.
|
||||
- id: short snake_case identifier
|
||||
- description: 1-3 sentences
|
||||
No commentary. ASCII only.
|
||||
""".strip()
|
||||
|
||||
user_prompt = f"""
|
||||
Scenario title: {scenario_meta.get("title", "")}
|
||||
Scenario description: {scenario_meta.get("description", "")}
|
||||
Node type: {node_type}
|
||||
Node path: {self._node_path(self.selected_node)}
|
||||
Current id: {ref.obj.get("id", "")}
|
||||
Current name: {ref.obj.get("name", "")}
|
||||
""".strip()
|
||||
|
||||
from omnia.llm_runtime import invoke_llm
|
||||
|
||||
response_text = invoke_llm(
|
||||
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
||||
)
|
||||
|
||||
try:
|
||||
payload = json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
self._toast("LLM response was not valid JSON", severity="error")
|
||||
return
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
self._toast("LLM response must be a JSON object", severity="error")
|
||||
return
|
||||
|
||||
generated_id = str(payload.get("id", "")).strip()
|
||||
generated_name = str(payload.get("name", "")).strip()
|
||||
generated_description = str(payload.get("description", "")).strip()
|
||||
|
||||
if not generated_id or not generated_name:
|
||||
self._toast("Generated data missing id or name", severity="error")
|
||||
return
|
||||
|
||||
existing_ids = self._collect_ids(node_type)
|
||||
if generated_id in existing_ids and generated_id != str(ref.obj.get("id", "")).strip():
|
||||
self._toast(
|
||||
f"Generated id '{generated_id}' already exists", severity="warning"
|
||||
)
|
||||
return
|
||||
|
||||
ref.obj["id"] = generated_id
|
||||
ref.obj["name"] = generated_name
|
||||
ref.obj["description"] = generated_description
|
||||
if self.selected_node:
|
||||
self.selected_node.label = self._label_for(node_type, ref.obj)
|
||||
self._load_form_from_node(self.selected_node)
|
||||
self._set_status("Generated fields applied")
|
||||
|
||||
def _reset_selected_node(self) -> None:
|
||||
if self.selected_node is None:
|
||||
return
|
||||
self._load_form_from_node(self.selected_node)
|
||||
self._set_status("Reverted changes")
|
||||
|
||||
def _add_node(self, node_type: str) -> None:
|
||||
world = self._world()
|
||||
expanded = self._capture_expanded_nodes()
|
||||
if node_type == "region":
|
||||
region = {"id": "", "name": "", "description": "", "locations": []}
|
||||
world.setdefault("regions", []).append(region)
|
||||
self._build_tree(select_obj=region, expanded=expanded)
|
||||
self._set_status("Region added")
|
||||
return
|
||||
|
||||
parent_node = self.selected_node
|
||||
if node_type == "location":
|
||||
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
|
||||
self._toast("Select a region or location for the new location")
|
||||
return
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type not in {"region", "location"}:
|
||||
self._toast("Select a region or location for the new location")
|
||||
return
|
||||
location = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"locations": [],
|
||||
"pois": [],
|
||||
}
|
||||
parent_ref.obj.setdefault("locations", []).append(location)
|
||||
self._build_tree(select_obj=location, expanded=expanded)
|
||||
self._set_status("Location added")
|
||||
return
|
||||
|
||||
if node_type == "poi":
|
||||
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
|
||||
self._toast("Select a location for the new POI")
|
||||
return
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type == "poi":
|
||||
parent_node = parent_node.parent
|
||||
if parent_node is None or not isinstance(
|
||||
parent_node.data, SpatialNodeRef
|
||||
):
|
||||
self._toast("Select a location for the new POI")
|
||||
return
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type != "location":
|
||||
self._toast("Select a location for the new POI")
|
||||
return
|
||||
poi = {"id": "", "name": "", "description": "", "connections": []}
|
||||
parent_ref.obj.setdefault("pois", []).append(poi)
|
||||
self._build_tree(select_obj=poi, expanded=expanded)
|
||||
self._set_status("POI added")
|
||||
|
||||
def _delete_node(self) -> None:
|
||||
ref = self._selected_ref()
|
||||
if ref is None or ref.node_type == "world":
|
||||
self._set_status("Select a node to delete")
|
||||
return
|
||||
parent_node = self._selected_parent_node()
|
||||
expanded = self._capture_expanded_nodes()
|
||||
if ref.node_type == "region":
|
||||
world = self._world()
|
||||
world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj]
|
||||
self._build_tree(expanded=expanded)
|
||||
self._set_status("Region deleted")
|
||||
return
|
||||
if ref.node_type == "location":
|
||||
parent_locations = self._parent_locations_list(parent_node)
|
||||
if parent_locations is None:
|
||||
self._set_status("Unable to delete location")
|
||||
return
|
||||
parent_locations[:] = [
|
||||
loc for loc in parent_locations if loc is not ref.obj
|
||||
]
|
||||
self._build_tree(expanded=expanded)
|
||||
self._set_status("Location deleted")
|
||||
return
|
||||
if ref.node_type == "poi":
|
||||
parent_pois = self._parent_pois_list(parent_node)
|
||||
if parent_pois is None:
|
||||
self._set_status("Unable to delete POI")
|
||||
return
|
||||
parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj]
|
||||
self._build_tree(expanded=expanded)
|
||||
self._set_status("POI deleted")
|
||||
|
||||
def _promote_location(self) -> None:
|
||||
ref = self._selected_ref()
|
||||
if ref is None or ref.node_type != "location":
|
||||
self._set_status("Select a location to promote")
|
||||
return
|
||||
parent_node = self._selected_parent_node()
|
||||
if parent_node is None or not isinstance(parent_node.data, SpatialNodeRef):
|
||||
self._set_status("Unable to promote location")
|
||||
return
|
||||
parent_ref = parent_node.data
|
||||
if parent_ref.node_type != "location":
|
||||
self._set_status("Location is already top-level")
|
||||
return
|
||||
grandparent_node = parent_node.parent
|
||||
parent_locations = parent_ref.obj.setdefault("locations", [])
|
||||
if ref.obj in parent_locations:
|
||||
parent_locations.remove(ref.obj)
|
||||
target_list = self._parent_locations_list(grandparent_node)
|
||||
if target_list is None:
|
||||
self._set_status("Unable to promote location")
|
||||
return
|
||||
target_list.append(ref.obj)
|
||||
expanded = self._capture_expanded_nodes()
|
||||
self._build_tree(select_obj=ref.obj, expanded=expanded)
|
||||
self._set_status("Location promoted")
|
||||
|
||||
def _demote_location(self) -> None:
|
||||
ref = self._selected_ref()
|
||||
if ref is None or ref.node_type != "location":
|
||||
self._set_status("Select a location to demote")
|
||||
return
|
||||
parent_node = self._selected_parent_node()
|
||||
parent_locations = self._parent_locations_list(parent_node)
|
||||
if parent_locations is None:
|
||||
self._set_status("Unable to demote location")
|
||||
return
|
||||
try:
|
||||
index = parent_locations.index(ref.obj)
|
||||
except ValueError:
|
||||
self._set_status("Unable to demote location")
|
||||
return
|
||||
if index >= len(parent_locations) - 1:
|
||||
self._set_status("No next sibling to demote under")
|
||||
return
|
||||
new_parent = parent_locations[index + 1]
|
||||
new_parent.setdefault("locations", []).append(ref.obj)
|
||||
parent_locations.remove(ref.obj)
|
||||
expanded = self._capture_expanded_nodes()
|
||||
self._build_tree(select_obj=ref.obj, expanded=expanded)
|
||||
self._set_status("Location demoted")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
if button_id == "spatial_add":
|
||||
self._open_node_type_prompt()
|
||||
return
|
||||
if button_id == "spatial_delete":
|
||||
self._delete_node()
|
||||
return
|
||||
if button_id == "spatial_promote":
|
||||
self._promote_location()
|
||||
return
|
||||
if button_id == "spatial_demote":
|
||||
self._demote_location()
|
||||
return
|
||||
if button_id == "node_generate":
|
||||
self._generate_node_fields()
|
||||
return
|
||||
if button_id == "node_update":
|
||||
self._update_selected_node()
|
||||
return
|
||||
if button_id == "node_reset":
|
||||
self._reset_selected_node()
|
||||
return
|
||||
if button_id == "spatial_save":
|
||||
self.dismiss(self.scenario)
|
||||
return
|
||||
if button_id == "spatial_cancel":
|
||||
self.dismiss(None)
|
||||
Reference in New Issue
Block a user