mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-03-24 09:17:39 +00:00
fast registry load minor fix on skill & registry stripe ros2 schema desc add create-device-skill new registry system backwards to yaml remove not exist resource new registry sys exp. support with add device add ai conventions correct raise create resource error ret info fix revert ret info fix fix prcxi check add create_resource schema re signal host ready event add websocket connection timeout and improve reconnection logic add open_timeout parameter to websocket connection add TimeoutError and InvalidStatus exception handling implement exponential backoff for reconnection attempts simplify reconnection logic flow add gzip change pose extra to any add isFlapY
419 lines
14 KiB
Python
419 lines
14 KiB
Python
"""
|
||
导入管理器
|
||
|
||
该模块提供了一个动态导入和管理模块的系统,避免误删未使用的导入。
|
||
"""
|
||
|
||
import builtins
|
||
import importlib
|
||
import inspect
|
||
import sys
|
||
import traceback
|
||
import ast
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Dict, List, Any, Optional, Callable, Type, Union, Tuple
|
||
|
||
__all__ = [
|
||
"ImportManager",
|
||
"default_manager",
|
||
"load_module",
|
||
"get_class",
|
||
"get_module",
|
||
"init_from_list",
|
||
"get_enhanced_class_info",
|
||
]
|
||
|
||
from unilabos.resources.resource_tracker import PARAM_SAMPLE_UUIDS
|
||
from unilabos.utils import logger
|
||
|
||
|
||
class ImportManager:
|
||
"""导入管理器类,用于动态加载和管理模块"""
|
||
|
||
def __init__(self, module_list: Optional[List[str]] = None):
|
||
"""
|
||
初始化导入管理器
|
||
|
||
Args:
|
||
module_list: 要预加载的模块路径列表
|
||
"""
|
||
self._modules: Dict[str, Any] = {}
|
||
self._classes: Dict[str, Type] = {}
|
||
self._functions: Dict[str, Callable] = {}
|
||
self._search_miss: set = set()
|
||
|
||
if module_list:
|
||
for module_path in module_list:
|
||
self.load_module(module_path)
|
||
|
||
def load_module(self, module_path: str) -> Any:
|
||
"""
|
||
加载指定路径的模块
|
||
|
||
Args:
|
||
module_path: 模块路径
|
||
|
||
Returns:
|
||
加载的模块对象
|
||
|
||
Raises:
|
||
ImportError: 如果模块导入失败
|
||
"""
|
||
try:
|
||
if module_path in self._modules:
|
||
return self._modules[module_path]
|
||
|
||
module = importlib.import_module(module_path)
|
||
self._modules[module_path] = module
|
||
|
||
# 索引模块中的类和函数
|
||
for name, obj in inspect.getmembers(module):
|
||
if inspect.isclass(obj):
|
||
full_name = f"{module_path}.{name}"
|
||
self._classes[name] = obj
|
||
self._classes[full_name] = obj
|
||
elif inspect.isfunction(obj):
|
||
full_name = f"{module_path}.{name}"
|
||
self._functions[name] = obj
|
||
self._functions[full_name] = obj
|
||
|
||
return module
|
||
except Exception as e:
|
||
logger.error(f"导入模块 '{module_path}' 时发生错误:{str(e)}")
|
||
logger.warning(traceback.format_exc())
|
||
raise ImportError(f"无法导入模块 {module_path}: {str(e)}")
|
||
|
||
def get_module(self, module_path: str) -> Any:
|
||
"""
|
||
获取已加载的模块
|
||
|
||
Args:
|
||
module_path: 模块路径
|
||
|
||
Returns:
|
||
模块对象
|
||
|
||
Raises:
|
||
KeyError: 如果模块未加载
|
||
"""
|
||
if module_path not in self._modules:
|
||
return self.load_module(module_path)
|
||
return self._modules[module_path]
|
||
|
||
def get_class(self, class_name: str) -> Type:
|
||
"""
|
||
获取类对象
|
||
|
||
Args:
|
||
class_name: 类名或完整类路径
|
||
|
||
Returns:
|
||
类对象
|
||
|
||
Raises:
|
||
KeyError: 如果找不到类
|
||
"""
|
||
if class_name in self._classes:
|
||
return self._classes[class_name]
|
||
|
||
# 尝试动态导入
|
||
if ":" in class_name:
|
||
module_path, cls_name = class_name.rsplit(":", 1)
|
||
module = self.load_module(module_path)
|
||
if hasattr(module, cls_name):
|
||
cls = getattr(module, cls_name)
|
||
self._classes[class_name] = cls
|
||
self._classes[cls_name] = cls
|
||
return cls
|
||
else:
|
||
# 如果cls_name是builtins中的关键字,则返回对应类
|
||
if class_name in builtins.__dict__:
|
||
return builtins.__dict__[class_name]
|
||
|
||
raise KeyError(f"找不到类: {class_name}")
|
||
|
||
def list_modules(self) -> List[str]:
|
||
"""列出所有已加载的模块路径"""
|
||
return list(self._modules.keys())
|
||
|
||
def list_classes(self) -> List[str]:
|
||
"""列出所有已索引的类名"""
|
||
return list(self._classes.keys())
|
||
|
||
def list_functions(self) -> List[str]:
|
||
"""列出所有已索引的函数名"""
|
||
return list(self._functions.keys())
|
||
|
||
def search_class(self, class_name: str, search_lower=False) -> Optional[Type]:
|
||
"""
|
||
在所有已加载的模块中搜索特定类名
|
||
|
||
Args:
|
||
class_name: 要搜索的类名
|
||
search_lower: 以小写搜索
|
||
|
||
Returns:
|
||
找到的类对象,如果未找到则返回None
|
||
"""
|
||
if class_name in builtins.__dict__:
|
||
return builtins.__dict__[class_name]
|
||
if class_name in self._classes:
|
||
return self._classes[class_name]
|
||
|
||
cache_key = class_name.lower() if search_lower else class_name
|
||
if cache_key in self._search_miss:
|
||
return None
|
||
|
||
if search_lower:
|
||
classes = {name.lower(): obj for name, obj in self._classes.items()}
|
||
if class_name in classes:
|
||
return classes[class_name]
|
||
|
||
for module_path, module in self._modules.items():
|
||
for name, obj in inspect.getmembers(module):
|
||
if inspect.isclass(obj) and (
|
||
(name.lower() == class_name.lower()) if search_lower else (name == class_name)
|
||
):
|
||
self._classes[name] = obj
|
||
self._classes[f"{module_path}:{name}"] = obj
|
||
return obj
|
||
|
||
self._search_miss.add(cache_key)
|
||
return None
|
||
|
||
def get_enhanced_class_info(self, module_path: str, **_kwargs) -> Dict[str, Any]:
|
||
"""通过 AST 分析获取类的增强信息。
|
||
|
||
复用 ``ast_registry_scanner`` 的 ``_collect_imports`` / ``_extract_class_body``,
|
||
与 AST 扫描注册表完全一致。
|
||
|
||
Args:
|
||
module_path: 格式 ``"module.path:ClassName"``
|
||
|
||
Returns:
|
||
``{"module_path", "ast_analysis_success", "import_map",
|
||
"init_params", "status_methods", "action_methods"}``
|
||
"""
|
||
from unilabos.registry.ast_registry_scanner import (
|
||
_collect_imports,
|
||
_extract_class_body,
|
||
_filepath_to_module,
|
||
)
|
||
|
||
result: Dict[str, Any] = {
|
||
"module_path": module_path,
|
||
"ast_analysis_success": False,
|
||
"import_map": {},
|
||
"init_params": [],
|
||
"status_methods": {},
|
||
"action_methods": {},
|
||
}
|
||
|
||
module_name, class_name = module_path.rsplit(":", 1)
|
||
file_path = self._module_path_to_file_path(module_name)
|
||
if not file_path or not os.path.exists(file_path):
|
||
logger.warning(f"[ImportManager] 找不到模块文件: {module_name} -> {file_path}")
|
||
return result
|
||
|
||
try:
|
||
with open(file_path, "r", encoding="utf-8") as f:
|
||
tree = ast.parse(f.read(), filename=file_path)
|
||
except Exception as e:
|
||
logger.warning(f"[ImportManager] 解析文件 {file_path} 失败: {e}")
|
||
return result
|
||
|
||
# 推导 module dotted path → 构建 import_map
|
||
python_path = Path(file_path)
|
||
for sp in sorted(sys.path, key=len, reverse=True):
|
||
try:
|
||
Path(file_path).relative_to(sp)
|
||
python_path = Path(sp)
|
||
break
|
||
except ValueError:
|
||
continue
|
||
module_dotted = _filepath_to_module(Path(file_path), python_path)
|
||
import_map = _collect_imports(tree, module_dotted)
|
||
result["import_map"] = import_map
|
||
|
||
# 定位目标类 AST 节点
|
||
target_class = None
|
||
for node in ast.walk(tree):
|
||
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
||
target_class = node
|
||
break
|
||
|
||
if target_class is None:
|
||
logger.warning(f"[ImportManager] 在文件 {file_path} 中找不到类 {class_name}")
|
||
return result
|
||
|
||
body = _extract_class_body(target_class, import_map)
|
||
|
||
# 映射到统一字段名(与 registry.py complete_registry 消费端一致)
|
||
result["init_params"] = body.get("init_params", [])
|
||
result["status_methods"] = body.get("status_properties", {})
|
||
result["action_methods"] = {
|
||
k: {
|
||
"args": v.get("params", []),
|
||
"return_type": v.get("return_type", ""),
|
||
"is_async": v.get("is_async", False),
|
||
"always_free": v.get("always_free", False),
|
||
"docstring": v.get("docstring"),
|
||
}
|
||
for k, v in body.get("auto_methods", {}).items()
|
||
}
|
||
result["ast_analysis_success"] = True
|
||
return result
|
||
|
||
def _analyze_method_signature(self, method, skip_unilabos_params: bool = True) -> Dict[str, Any]:
|
||
"""
|
||
分析方法签名,提取具体的命名参数信息
|
||
|
||
注意:此方法会跳过*args和**kwargs,只提取具体的命名参数
|
||
这样可以确保通过**dict方式传参时的准确性
|
||
|
||
Args:
|
||
method: 要分析的方法
|
||
skip_unilabos_params: 是否跳过 unilabos 系统参数(如 sample_uuids),
|
||
registry 补全时为 True,JsonCommand 执行时为 False
|
||
|
||
示例用法:
|
||
method_info = self._analyze_method_signature(some_method)
|
||
params = {"param1": "value1", "param2": "value2"}
|
||
result = some_method(**params) # 安全的参数传递
|
||
"""
|
||
signature = inspect.signature(method)
|
||
args = []
|
||
num_required = 0
|
||
|
||
for param_name, param in signature.parameters.items():
|
||
# 跳过self参数
|
||
if param_name == "self":
|
||
continue
|
||
|
||
# 跳过*args和**kwargs参数
|
||
if param.kind == param.VAR_POSITIONAL: # *args
|
||
continue
|
||
if param.kind == param.VAR_KEYWORD: # **kwargs
|
||
continue
|
||
|
||
# 跳过 sample_uuids 参数(由系统自动注入,registry 补全时跳过)
|
||
if skip_unilabos_params and param_name == PARAM_SAMPLE_UUIDS:
|
||
continue
|
||
|
||
is_required = param.default == inspect.Parameter.empty
|
||
if is_required:
|
||
num_required += 1
|
||
|
||
args.append(
|
||
{
|
||
"name": param_name,
|
||
"type": self._get_type_string(param.annotation),
|
||
"required": is_required,
|
||
"default": None if param.default == inspect.Parameter.empty else param.default,
|
||
}
|
||
)
|
||
|
||
return {
|
||
"name": method.__name__,
|
||
"args": args,
|
||
"return_type": self._get_type_string(signature.return_annotation),
|
||
"is_async": inspect.iscoroutinefunction(method),
|
||
}
|
||
|
||
def _get_return_type_from_method(self, method) -> Union[str, Tuple[str, Any]]:
|
||
"""从方法中获取返回类型"""
|
||
signature = inspect.signature(method)
|
||
return self._get_type_string(signature.return_annotation)
|
||
|
||
def _get_type_string(self, annotation) -> Union[str, Tuple[str, Any]]:
|
||
"""将类型注解转换为类型字符串。
|
||
|
||
非内建类返回 ``module:ClassName`` 全路径(如
|
||
``"unilabos.registry.placeholder_type:ResourceSlot"``),
|
||
避免短名冲突;内建类型直接返回短名(如 ``"str"``、``"int"``)。
|
||
"""
|
||
if annotation == inspect.Parameter.empty:
|
||
return "Any"
|
||
if annotation is None:
|
||
return "None"
|
||
if hasattr(annotation, "__origin__"):
|
||
origin = annotation.__origin__
|
||
if origin in (list, set, tuple):
|
||
if hasattr(annotation, "__args__") and annotation.__args__:
|
||
if len(annotation.__args__):
|
||
arg0 = annotation.__args__[0]
|
||
if isinstance(arg0, int):
|
||
return "Int64MultiArray"
|
||
elif isinstance(arg0, float):
|
||
return "Float64MultiArray"
|
||
return "list", self._get_type_string(arg0)
|
||
elif origin is dict:
|
||
return "dict"
|
||
elif origin is Optional:
|
||
return "Unknown"
|
||
return "Unknown"
|
||
annotation_str = str(annotation)
|
||
if "typing." in annotation_str:
|
||
return (
|
||
annotation_str.replace("typing.", "")
|
||
if getattr(annotation, "_name", None) is None
|
||
else annotation._name.lower()
|
||
)
|
||
if hasattr(annotation, "__name__"):
|
||
module = getattr(annotation, "__module__", None)
|
||
if module and module != "builtins":
|
||
return f"{module}:{annotation.__name__}"
|
||
return annotation.__name__
|
||
elif hasattr(annotation, "_name"):
|
||
return annotation._name
|
||
elif isinstance(annotation, str):
|
||
return annotation
|
||
else:
|
||
return annotation_str
|
||
|
||
def _module_path_to_file_path(self, module_path: str) -> Optional[str]:
|
||
for path in sys.path:
|
||
potential_path = Path(path) / module_path.replace(".", "/")
|
||
|
||
# 检查是否为包
|
||
if (potential_path / "__init__.py").exists():
|
||
return str(potential_path / "__init__.py")
|
||
|
||
# 检查是否为模块文件
|
||
if (potential_path.parent / f"{potential_path.name}.py").exists():
|
||
return str(potential_path.parent / f"{potential_path.name}.py")
|
||
|
||
return None
|
||
|
||
|
||
|
||
# 全局实例,便于直接使用
|
||
default_manager = ImportManager()
|
||
|
||
|
||
def load_module(module_path: str) -> Any:
|
||
"""加载模块的便捷函数"""
|
||
return default_manager.load_module(module_path)
|
||
|
||
|
||
def get_class(class_name: str) -> Type:
|
||
"""获取类的便捷函数"""
|
||
return default_manager.get_class(class_name)
|
||
|
||
|
||
def get_module(module_path: str) -> Any:
|
||
"""获取模块的便捷函数"""
|
||
return default_manager.get_module(module_path)
|
||
|
||
|
||
def init_from_list(module_list: List[str]) -> None:
|
||
"""从模块列表初始化默认管理器"""
|
||
global default_manager
|
||
default_manager = ImportManager(module_list)
|
||
|
||
|
||
def get_enhanced_class_info(module_path: str, **kwargs) -> Dict[str, Any]:
|
||
"""获取增强的类信息的便捷函数"""
|
||
return default_manager.get_enhanced_class_info(module_path, **kwargs)
|