""" 导入管理器 该模块提供了一个动态导入和管理模块的系统,避免误删未使用的导入。 """ import builtins import importlib import inspect import sys import traceback import ast import os from pathlib import Path from typing import Dict, List, Any, Optional, Callable, Type, Union, Tuple __all__ = [ "ImportManager", "default_manager", "load_module", "get_class", "get_module", "init_from_list", "get_enhanced_class_info", ] from unilabos.resources.resource_tracker import PARAM_SAMPLE_UUIDS from unilabos.utils import logger class ImportManager: """导入管理器类,用于动态加载和管理模块""" def __init__(self, module_list: Optional[List[str]] = None): """ 初始化导入管理器 Args: module_list: 要预加载的模块路径列表 """ self._modules: Dict[str, Any] = {} self._classes: Dict[str, Type] = {} self._functions: Dict[str, Callable] = {} if module_list: for module_path in module_list: self.load_module(module_path) def load_module(self, module_path: str) -> Any: """ 加载指定路径的模块 Args: module_path: 模块路径 Returns: 加载的模块对象 Raises: ImportError: 如果模块导入失败 """ try: if module_path in self._modules: return self._modules[module_path] module = importlib.import_module(module_path) self._modules[module_path] = module # 索引模块中的类和函数 for name, obj in inspect.getmembers(module): if inspect.isclass(obj): full_name = f"{module_path}.{name}" self._classes[name] = obj self._classes[full_name] = obj elif inspect.isfunction(obj): full_name = f"{module_path}.{name}" self._functions[name] = obj self._functions[full_name] = obj return module except Exception as e: logger.error(f"导入模块 '{module_path}' 时发生错误:{str(e)}") logger.warning(traceback.format_exc()) raise ImportError(f"无法导入模块 {module_path}: {str(e)}") def get_module(self, module_path: str) -> Any: """ 获取已加载的模块 Args: module_path: 模块路径 Returns: 模块对象 Raises: KeyError: 如果模块未加载 """ if module_path not in self._modules: return self.load_module(module_path) return self._modules[module_path] def get_class(self, class_name: str) -> Type: """ 获取类对象 Args: class_name: 类名或完整类路径 Returns: 类对象 Raises: KeyError: 如果找不到类 """ if class_name in self._classes: return self._classes[class_name] # 尝试动态导入 if ":" in class_name: module_path, cls_name = class_name.rsplit(":", 1) module = self.load_module(module_path) if hasattr(module, cls_name): cls = getattr(module, cls_name) self._classes[class_name] = cls self._classes[cls_name] = cls return cls else: # 如果cls_name是builtins中的关键字,则返回对应类 if class_name in builtins.__dict__: return builtins.__dict__[class_name] raise KeyError(f"找不到类: {class_name}") def list_modules(self) -> List[str]: """列出所有已加载的模块路径""" return list(self._modules.keys()) def list_classes(self) -> List[str]: """列出所有已索引的类名""" return list(self._classes.keys()) def list_functions(self) -> List[str]: """列出所有已索引的函数名""" return list(self._functions.keys()) def search_class(self, class_name: str, search_lower=False) -> Optional[Type]: """ 在所有已加载的模块中搜索特定类名 Args: class_name: 要搜索的类名 search_lower: 以小写搜索 Returns: 找到的类对象,如果未找到则返回None """ # 如果cls_name是builtins中的关键字,则返回对应类 if class_name in builtins.__dict__: return builtins.__dict__[class_name] # 首先在已索引的类中查找 if class_name in self._classes: return self._classes[class_name] if search_lower: classes = {name.lower(): obj for name, obj in self._classes.items()} if class_name in classes: return classes[class_name] # 遍历所有已加载的模块进行搜索 for module_path, module in self._modules.items(): for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and ( (name.lower() == class_name.lower()) if search_lower else (name == class_name) ): # 将找到的类添加到索引中 self._classes[name] = obj self._classes[f"{module_path}:{name}"] = obj return obj return None def get_enhanced_class_info(self, module_path: str, **_kwargs) -> Dict[str, Any]: """通过 AST 分析获取类的增强信息。 复用 ``ast_registry_scanner`` 的 ``_collect_imports`` / ``_extract_class_body``, 与 AST 扫描注册表完全一致。 Args: module_path: 格式 ``"module.path:ClassName"`` Returns: ``{"module_path", "ast_analysis_success", "import_map", "init_params", "status_methods", "action_methods"}`` """ from unilabos.registry.ast_registry_scanner import ( _collect_imports, _extract_class_body, _filepath_to_module, ) result: Dict[str, Any] = { "module_path": module_path, "ast_analysis_success": False, "import_map": {}, "init_params": [], "status_methods": {}, "action_methods": {}, } module_name, class_name = module_path.rsplit(":", 1) file_path = self._module_path_to_file_path(module_name) if not file_path or not os.path.exists(file_path): logger.warning(f"[ImportManager] 找不到模块文件: {module_name} -> {file_path}") return result try: with open(file_path, "r", encoding="utf-8") as f: tree = ast.parse(f.read(), filename=file_path) except Exception as e: logger.warning(f"[ImportManager] 解析文件 {file_path} 失败: {e}") return result # 推导 module dotted path → 构建 import_map python_path = Path(file_path) for sp in sorted(sys.path, key=len, reverse=True): try: Path(file_path).relative_to(sp) python_path = Path(sp) break except ValueError: continue module_dotted = _filepath_to_module(Path(file_path), python_path) import_map = _collect_imports(tree, module_dotted) result["import_map"] = import_map # 定位目标类 AST 节点 target_class = None for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: target_class = node break if target_class is None: logger.warning(f"[ImportManager] 在文件 {file_path} 中找不到类 {class_name}") return result body = _extract_class_body(target_class, import_map) # 映射到统一字段名(与 registry.py complete_registry 消费端一致) result["init_params"] = body.get("init_params", []) result["status_methods"] = body.get("status_properties", {}) result["action_methods"] = { k: { "args": v.get("params", []), "return_type": v.get("return_type", ""), "is_async": v.get("is_async", False), "always_free": v.get("always_free", False), "docstring": v.get("docstring"), } for k, v in body.get("auto_methods", {}).items() } result["ast_analysis_success"] = True return result def _analyze_method_signature(self, method, skip_unilabos_params: bool = True) -> Dict[str, Any]: """ 分析方法签名,提取具体的命名参数信息 注意:此方法会跳过*args和**kwargs,只提取具体的命名参数 这样可以确保通过**dict方式传参时的准确性 Args: method: 要分析的方法 skip_unilabos_params: 是否跳过 unilabos 系统参数(如 sample_uuids), registry 补全时为 True,JsonCommand 执行时为 False 示例用法: method_info = self._analyze_method_signature(some_method) params = {"param1": "value1", "param2": "value2"} result = some_method(**params) # 安全的参数传递 """ signature = inspect.signature(method) args = [] num_required = 0 for param_name, param in signature.parameters.items(): # 跳过self参数 if param_name == "self": continue # 跳过*args和**kwargs参数 if param.kind == param.VAR_POSITIONAL: # *args continue if param.kind == param.VAR_KEYWORD: # **kwargs continue # 跳过 sample_uuids 参数(由系统自动注入,registry 补全时跳过) if skip_unilabos_params and param_name == PARAM_SAMPLE_UUIDS: continue is_required = param.default == inspect.Parameter.empty if is_required: num_required += 1 args.append( { "name": param_name, "type": self._get_type_string(param.annotation), "required": is_required, "default": None if param.default == inspect.Parameter.empty else param.default, } ) return { "name": method.__name__, "args": args, "return_type": self._get_type_string(signature.return_annotation), "is_async": inspect.iscoroutinefunction(method), } def _get_return_type_from_method(self, method) -> Union[str, Tuple[str, Any]]: """从方法中获取返回类型""" signature = inspect.signature(method) return self._get_type_string(signature.return_annotation) def _get_type_string(self, annotation) -> Union[str, Tuple[str, Any]]: """将类型注解转换为短类名(与 AST _get_annotation_str 对齐)。 自定义类只返回短名(如 ``"SetLiquidReturn"``),完整路径由 ``import_map`` 负责解析,保持与 AST 分析一致。 """ if annotation == inspect.Parameter.empty: return "Any" if annotation is None: return "None" if hasattr(annotation, "__origin__"): origin = annotation.__origin__ if origin in (list, set, tuple): if hasattr(annotation, "__args__") and annotation.__args__: if len(annotation.__args__): arg0 = annotation.__args__[0] if isinstance(arg0, int): return "Int64MultiArray" elif isinstance(arg0, float): return "Float64MultiArray" return "list", self._get_type_string(arg0) elif origin is dict: return "dict" elif origin is Optional: return "Unknown" return "Unknown" annotation_str = str(annotation) if "typing." in annotation_str: return ( annotation_str.replace("typing.", "") if getattr(annotation, "_name", None) is None else annotation._name.lower() ) if hasattr(annotation, "__name__"): return annotation.__name__ elif hasattr(annotation, "_name"): return annotation._name elif isinstance(annotation, str): return annotation else: return annotation_str def _module_path_to_file_path(self, module_path: str) -> Optional[str]: for path in sys.path: potential_path = Path(path) / module_path.replace(".", "/") # 检查是否为包 if (potential_path / "__init__.py").exists(): return str(potential_path / "__init__.py") # 检查是否为模块文件 if (potential_path.parent / f"{potential_path.name}.py").exists(): return str(potential_path.parent / f"{potential_path.name}.py") return None # 全局实例,便于直接使用 default_manager = ImportManager() def load_module(module_path: str) -> Any: """加载模块的便捷函数""" return default_manager.load_module(module_path) def get_class(class_name: str) -> Type: """获取类的便捷函数""" return default_manager.get_class(class_name) def get_module(module_path: str) -> Any: """获取模块的便捷函数""" return default_manager.get_module(module_path) def init_from_list(module_list: List[str]) -> None: """从模块列表初始化默认管理器""" global default_manager default_manager = ImportManager(module_list) def get_enhanced_class_info(module_path: str, **kwargs) -> Dict[str, Any]: """获取增强的类信息的便捷函数""" return default_manager.get_enhanced_class_info(module_path, **kwargs)