feat: Allow changing of parent node and fix tree collapse

This commit is contained in:
2026-04-18 18:04:38 +05:30
parent 88bd58d984
commit a38d078d51

View File

@@ -451,6 +451,7 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
self.scenario = scenario
self.selected_node = None
self.node_map: dict[int, Any] = {}
self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {}
def compose(self) -> ComposeResult:
yield Header(show_clock=True)
@@ -471,14 +472,16 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
yield Label("", id="node-type-value")
yield Label("ID", id="node-id-label")
yield Input(id="node_id")
yield Label("Name", id="node-name-label")
yield Input(id="node_name")
yield Label("Description", id="node-desc-label")
yield TextArea(id="node-description")
with Horizontal(id="node-actions"):
yield Button("Generate", id="node_generate")
yield Button("Update", id="node_update", variant="primary")
yield Button("Reset", id="node_reset")
yield Label("Name", id="node-name-label")
yield Input(id="node_name")
yield Label("Description", id="node-desc-label")
yield TextArea(id="node-description")
yield Label("Parent", id="node-parent-label")
yield Select([], prompt="Select parent", id="node_parent")
with Horizontal(id="node-actions"):
yield Button("Generate", id="node_generate")
yield Button("Update", id="node_update", variant="primary")
yield Button("Reset", id="node_reset")
with Horizontal(id="spatial-actions"):
yield Button("Save", id="spatial_save", variant="primary")
yield Button("Cancel", id="spatial_cancel")
@@ -513,7 +516,11 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
def _world(self) -> dict[str, Any]:
return self.scenario["spatial_graph"]["world"]
def _build_tree(self, select_obj: Optional[dict[str, Any]] = None) -> None:
def _build_tree(
self,
select_obj: Optional[dict[str, Any]] = None,
expanded: Optional[set[int]] = None,
) -> None:
tree = self.query_one("#spatial-tree", Tree)
world = self._world()
tree.reset(f"🌍 {world.get('name', 'World')}")
@@ -521,12 +528,15 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
tree.root.expand()
self.selected_node = None
self.node_map = {}
expanded = expanded or set()
def add_region(region: dict[str, Any]) -> None:
node = tree.root.add(
self._label_for("region", region), data=SpatialNodeRef("region", region)
)
self.node_map[id(region)] = node
if id(region) in expanded:
node.expand()
for location in region.get("locations", []):
add_location(node, location)
@@ -536,6 +546,8 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
data=SpatialNodeRef("location", location),
)
self.node_map[id(location)] = loc_node
if id(location) in expanded:
loc_node.expand()
for nested in location.get("locations", []):
add_location(loc_node, nested)
for poi in location.get("pois", []):
@@ -543,16 +555,38 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
self._label_for("poi", poi), data=SpatialNodeRef("poi", poi)
)
self.node_map[id(poi)] = poi_node
if id(poi) in expanded:
poi_node.expand()
for region in world.get("regions", []):
add_region(region)
if select_obj and id(select_obj) in self.node_map:
self.selected_node = self.node_map[id(select_obj)]
self._expand_ancestors(self.selected_node)
self._load_form_from_node(self.selected_node)
else:
self._clear_form()
def _capture_expanded_nodes(self) -> set[int]:
tree = self.query_one("#spatial-tree", Tree)
expanded: set[int] = set()
def walk(node: Any) -> None:
if node.is_expanded and isinstance(node.data, SpatialNodeRef):
expanded.add(id(node.data.obj))
for child in node.children:
walk(child)
walk(tree.root)
return expanded
def _expand_ancestors(self, node: Any) -> None:
current = node
while current is not None:
current.expand()
current = current.parent
def _label_for(self, node_type: str, obj: dict[str, Any]) -> str:
name = str(obj.get("name") or "Unnamed").strip()
node_id = str(obj.get("id") or "no-id").strip()
@@ -564,6 +598,8 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
self.query_one("#node_id", Input).value = ""
self.query_one("#node_name", Input).value = ""
self.query_one("#node-description", TextArea).text = ""
self.query_one("#node_parent", Select).set_options([])
self.query_one("#node_parent", Select).value = Select.NULL
self.query_one("#node-details", Vertical).add_class("hidden")
self.query_one("#node_generate", Button).disabled = True
self.query_one("#node_update", Button).disabled = True
@@ -594,6 +630,20 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
)
self.query_one("#spatial_promote", Button).disabled = not can_promote
self.query_one("#spatial_demote", Button).disabled = not can_demote
parent_select = self.query_one("#node_parent", Select)
if is_location:
options, selected_value, option_map = self._build_parent_options(ref.obj)
self.parent_option_map = option_map
parent_select.set_options(options)
parent_select.value = (
selected_value if selected_value is not None else Select.NULL
)
parent_select.disabled = False
else:
self.parent_option_map = {}
parent_select.set_options([])
parent_select.value = Select.NULL
parent_select.disabled = True
def on_tree_node_selected(self, event: Tree.NodeSelected) -> None:
self.selected_node = event.node
@@ -657,6 +707,91 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
return None
return parent_ref.obj.setdefault("pois", [])
def _collect_locations(self) -> list[dict[str, Any]]:
locations: list[dict[str, Any]] = []
def walk(location_list: list[dict[str, Any]]) -> None:
for location in location_list:
locations.append(location)
walk(location.get("locations", []))
for region in self._world().get("regions", []):
walk(region.get("locations", []))
return locations
def _location_descendants(self, location: dict[str, Any]) -> set[int]:
descendants: set[int] = set()
def walk(children: list[dict[str, Any]]) -> None:
for child in children:
descendants.add(id(child))
walk(child.get("locations", []))
walk(location.get("locations", []))
return descendants
def _find_location_parent(self, target: dict[str, Any]) -> Optional[tuple[str, dict[str, Any]]]:
world = self._world()
for region in world.get("regions", []):
if target in region.get("locations", []):
return ("region", region)
for location in region.get("locations", []):
found = self._find_location_parent_in_location(location, target)
if found:
return found
return None
def _find_location_parent_in_location(
self, location: dict[str, Any], target: dict[str, Any]
) -> Optional[tuple[str, dict[str, Any]]]:
if target in location.get("locations", []):
return ("location", location)
for nested in location.get("locations", []):
found = self._find_location_parent_in_location(nested, target)
if found:
return found
return None
def _build_parent_options(
self, current: dict[str, Any]
) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]:
options: list[tuple[str, str]] = []
option_map: dict[str, tuple[str, dict[str, Any]]] = {}
forbidden = {id(current)} | self._location_descendants(current)
selected_value: Optional[str] = None
for region in self._world().get("regions", []):
region_id = str(region.get("id", "")).strip()
region_name = str(region.get("name", "")).strip()
value = f"region@{id(region)}"
label = f"🗺️ {region_name} ({region_id})" if region_name else f"🗺️ {region_id or 'region'}"
options.append((label, value))
option_map[value] = ("region", region)
for location in self._collect_locations():
if id(location) in forbidden:
continue
location_id = str(location.get("id", "")).strip()
location_name = str(location.get("name", "")).strip()
value = f"location@{id(location)}"
label = (
f"📍 {location_name} ({location_id})"
if location_name
else f"📍 {location_id or 'location'}"
)
options.append((label, value))
option_map[value] = ("location", location)
parent_info = self._find_location_parent(current)
if parent_info:
parent_type, parent_obj = parent_info
for value, (entry_type, entry_obj) in option_map.items():
if entry_type == parent_type and entry_obj is parent_obj:
selected_value = value
break
return options, selected_value, option_map
def _update_selected_node(self) -> None:
ref = self._selected_ref()
if ref is None or ref.node_type == "world":
@@ -678,6 +813,26 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]):
ref.obj.setdefault("pois", [])
elif ref.node_type == "poi":
ref.obj.setdefault("connections", [])
if ref.node_type == "location":
parent_select = self.query_one("#node_parent", Select)
if parent_select.value is not Select.NULL:
selected_value = str(parent_select.value)
parent_info = self.parent_option_map.get(selected_value)
if parent_info:
parent_type, parent_obj = parent_info
current_parent = self._find_location_parent(ref.obj)
if current_parent is None or current_parent[1] is not parent_obj:
if current_parent:
current_type, current_obj = current_parent
if current_type == "region":
current_obj.get("locations", []).remove(ref.obj)
elif current_type == "location":
current_obj.get("locations", []).remove(ref.obj)
parent_obj.setdefault("locations", []).append(ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Node updated")
return
if self.selected_node:
self.selected_node.label = self._label_for(ref.node_type, ref.obj)
self._set_status("Node updated")
@@ -791,10 +946,11 @@ Current name: {ref.obj.get("name", "")}
def _add_node(self, node_type: str) -> None:
world = self._world()
expanded = self._capture_expanded_nodes()
if node_type == "region":
region = {"id": "", "name": "", "description": "", "locations": []}
world.setdefault("regions", []).append(region)
self._build_tree(select_obj=region)
self._build_tree(select_obj=region, expanded=expanded)
self._set_status("Region added")
return
@@ -815,7 +971,7 @@ Current name: {ref.obj.get("name", "")}
"pois": [],
}
parent_ref.obj.setdefault("locations", []).append(location)
self._build_tree(select_obj=location)
self._build_tree(select_obj=location, expanded=expanded)
self._set_status("Location added")
return
@@ -837,7 +993,7 @@ Current name: {ref.obj.get("name", "")}
return
poi = {"id": "", "name": "", "description": "", "connections": []}
parent_ref.obj.setdefault("pois", []).append(poi)
self._build_tree(select_obj=poi)
self._build_tree(select_obj=poi, expanded=expanded)
self._set_status("POI added")
def _delete_node(self) -> None:
@@ -846,10 +1002,11 @@ Current name: {ref.obj.get("name", "")}
self._set_status("Select a node to delete")
return
parent_node = self._selected_parent_node()
expanded = self._capture_expanded_nodes()
if ref.node_type == "region":
world = self._world()
world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj]
self._build_tree()
self._build_tree(expanded=expanded)
self._set_status("Region deleted")
return
if ref.node_type == "location":
@@ -860,7 +1017,7 @@ Current name: {ref.obj.get("name", "")}
parent_locations[:] = [
loc for loc in parent_locations if loc is not ref.obj
]
self._build_tree()
self._build_tree(expanded=expanded)
self._set_status("Location deleted")
return
if ref.node_type == "poi":
@@ -869,7 +1026,7 @@ Current name: {ref.obj.get("name", "")}
self._set_status("Unable to delete POI")
return
parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj]
self._build_tree()
self._build_tree(expanded=expanded)
self._set_status("POI deleted")
def _promote_location(self) -> None:
@@ -894,7 +1051,8 @@ Current name: {ref.obj.get("name", "")}
self._set_status("Unable to promote location")
return
target_list.append(ref.obj)
self._build_tree(select_obj=ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Location promoted")
def _demote_location(self) -> None:
@@ -918,7 +1076,8 @@ Current name: {ref.obj.get("name", "")}
new_parent = parent_locations[index + 1]
new_parent.setdefault("locations", []).append(ref.obj)
parent_locations.remove(ref.obj)
self._build_tree(select_obj=ref.obj)
expanded = self._capture_expanded_nodes()
self._build_tree(select_obj=ref.obj, expanded=expanded)
self._set_status("Location demoted")
def on_button_pressed(self, event: Button.Pressed) -> None: