fast registry load

This commit is contained in:
Xuwznln
2026-03-22 04:14:47 +08:00
parent 427afe83d4
commit d8922884b1
2 changed files with 104 additions and 9 deletions

View File

@@ -12,6 +12,7 @@ import io
import os
import sys
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
@@ -102,6 +103,7 @@ class Registry:
self.device_type_registry: Dict[str, Any] = {}
self.resource_type_registry: Dict[str, Any] = {}
self._type_resolve_cache: Dict[str, Any] = {}
self._setup_called = False
self._startup_executor: Optional[ThreadPoolExecutor] = None
@@ -412,10 +414,20 @@ class Registry:
# ------------------------------------------------------------------
def _replace_type_with_class(self, type_name: str, device_id: str, field_name: str) -> Any:
"""将类型名称替换为实际的 ROS 消息类对象"""
"""将类型名称替换为实际的 ROS 消息类对象(带缓存)"""
if not type_name or type_name == "":
return type_name
cached = self._type_resolve_cache.get(type_name)
if cached is not None:
return cached
result = self._resolve_type_uncached(type_name, device_id, field_name)
self._type_resolve_cache[type_name] = result
return result
def _resolve_type_uncached(self, type_name: str, device_id: str, field_name: str) -> Any:
"""实际的类型解析逻辑(无缓存)"""
# 泛型类型映射
if "[" in type_name:
generic_mapping = {
@@ -451,7 +463,6 @@ class Registry:
if type_class:
return type_class
else:
# dataclass / TypedDict 等非 ROS2 类型,序列化为 JSON 字符串
logger.trace(
f"[Registry] 类型 '{type_name}' 非 ROS2 消息类型 (设备 {device_id} {field_name}),映射为 String"
)
@@ -1151,8 +1162,8 @@ class Registry:
tmp = cache_path.with_suffix(".tmp")
tmp.write_bytes(pickle.dumps(cache, protocol=pickle.HIGHEST_PROTOCOL))
tmp.replace(cache_path)
except Exception:
pass
except Exception as e:
logger.debug(f"[UniLab Registry] 缓存保存失败: {e}")
@staticmethod
def _module_source_hash(module_str: str) -> Optional[str]:
@@ -1454,11 +1465,16 @@ class Registry:
仅做 ROS2 消息类型查找,不 import 任何设备模块,速度快且无副作用。
"""
t0 = time.time()
for device_id in list(self.device_type_registry):
try:
self.resolve_types_for_device(device_id)
except Exception as e:
logger.debug(f"[Registry] 设备 {device_id} 类型解析失败: {e}")
logger.info(
f"[UniLab Registry] 类型解析完成: {len(self.device_type_registry)} 设备 "
f"(耗时 {time.time() - t0:.2f}s)"
)
# ------------------------------------------------------------------
# YAML 注册表加载 (兼容旧格式)
@@ -1946,7 +1962,31 @@ class Registry:
return data, complete_data, True, device_ids
def _rebuild_device_runtime_data(self, complete_data: Dict[str, Any]) -> Dict[str, Any]:
"""从 complete_data纯字符串重建运行时数据type 字段替换为 class 对象)。"""
data = copy.deepcopy(complete_data)
for device_id, device_config in data.items():
if "class" not in device_config:
continue
# status_types: str → class
for st_name, st_type in device_config["class"].get("status_types", {}).items():
if isinstance(st_type, str):
device_config["class"]["status_types"][st_name] = self._replace_type_with_class(
st_type, device_id, f"状态 {st_name}"
)
# action type: str → class (non-UniLabJsonCommand only)
for _act_name, act_cfg in device_config["class"].get("action_value_mappings", {}).items():
t_ref = act_cfg.get("type", "")
if isinstance(t_ref, str) and t_ref and not t_ref.startswith("UniLabJsonCommand"):
resolved = self._replace_type_with_class(t_ref, device_id, f"动作 {_act_name}")
if resolved:
act_cfg["type"] = resolved
self._add_builtin_actions(device_config, device_id)
return data
def load_device_types(self, path: os.PathLike, complete_registry: bool = False):
import hashlib as _hl
t0 = time.time()
abs_path = Path(path).absolute()
devices_path = abs_path / "devices"
device_comms_path = abs_path / "device_comms"
@@ -1959,12 +1999,41 @@ class Registry:
if not files:
return
config_cache = self._load_config_cache()
yaml_dev_cache: dict = config_cache.get("_yaml_devices", {})
cache_hits = 0
uncached_files: list[Path] = []
if complete_registry:
uncached_files = files
else:
for file in files:
file_key = str(file.absolute()).replace("\\", "/")
try:
yaml_md5 = _hl.md5(file.read_bytes()).hexdigest()
except OSError:
uncached_files.append(file)
continue
cached = yaml_dev_cache.get(file_key)
if cached and cached.get("yaml_md5") == yaml_md5 and cached.get("entries"):
complete_data = cached["entries"]
# 过滤掉 AST 已有的设备
complete_data = {
did: cfg for did, cfg in complete_data.items()
if not self.device_type_registry.get(did)
}
runtime_data = self._rebuild_device_runtime_data(complete_data)
self.device_type_registry.update(runtime_data)
cache_hits += 1
continue
uncached_files.append(file)
executor = self._startup_executor
future_to_file = {
executor.submit(
self._load_single_device_file, file, complete_registry
): file
for file in files
for file in uncached_files
}
for future in as_completed(future_to_file):
@@ -1974,9 +2043,33 @@ class Registry:
if is_valid:
runtime_data = {did: data[did] for did in device_ids if did in data}
self.device_type_registry.update(runtime_data)
# 写入缓存
file_key = str(file.absolute()).replace("\\", "/")
try:
yaml_md5 = _hl.md5(file.read_bytes()).hexdigest()
yaml_dev_cache[file_key] = {
"yaml_md5": yaml_md5,
"entries": _complete_data,
}
except OSError:
pass
except Exception as e:
logger.warning(f"[UniLab Registry] 加载设备文件失败: {file}, 错误: {e}")
if uncached_files and yaml_dev_cache:
latest_cache = self._load_config_cache()
latest_cache["_yaml_devices"] = yaml_dev_cache
self._save_config_cache(latest_cache)
total = len(files)
extra = " (complete_registry 跳过缓存)" if complete_registry else ""
logger.info(
f"[UniLab Registry] YAML 设备加载: "
f"{cache_hits}/{total} 缓存命中, "
f"{len(uncached_files)} 重新加载 "
f"(耗时 {time.time() - t0:.2f}s){extra}"
)
# ------------------------------------------------------------------
# 注册表信息输出
# ------------------------------------------------------------------

View File

@@ -41,6 +41,7 @@ class ImportManager:
self._modules: Dict[str, Any] = {}
self._classes: Dict[str, Type] = {}
self._functions: Dict[str, Callable] = {}
self._search_miss: set = set()
if module_list:
for module_path in module_list:
@@ -155,29 +156,30 @@ class ImportManager:
Returns:
找到的类对象如果未找到则返回None
"""
# 如果cls_name是builtins中的关键字则返回对应类
if class_name in builtins.__dict__:
return builtins.__dict__[class_name]
# 首先在已索引的类中查找
if class_name in self._classes:
return self._classes[class_name]
cache_key = class_name.lower() if search_lower else class_name
if cache_key in self._search_miss:
return None
if search_lower:
classes = {name.lower(): obj for name, obj in self._classes.items()}
if class_name in classes:
return classes[class_name]
# 遍历所有已加载的模块进行搜索
for module_path, module in self._modules.items():
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and (
(name.lower() == class_name.lower()) if search_lower else (name == class_name)
):
# 将找到的类添加到索引中
self._classes[name] = obj
self._classes[f"{module_path}:{name}"] = obj
return obj
self._search_miss.add(cache_key)
return None
def get_enhanced_class_info(self, module_path: str, **_kwargs) -> Dict[str, Any]: