mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-03-25 20:33:13 +00:00
Merge branch 'dev' into prcix9320
This commit is contained in:
@@ -19,74 +19,6 @@ def singleton(cls):
|
||||
return get_instance
|
||||
|
||||
|
||||
def topic_config(
|
||||
period: Optional[float] = None,
|
||||
print_publish: Optional[bool] = None,
|
||||
qos: Optional[int] = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""
|
||||
Topic发布配置装饰器
|
||||
|
||||
用于装饰 get_{attr_name} 方法或 @property,控制对应属性的ROS topic发布行为。
|
||||
|
||||
Args:
|
||||
period: 发布周期(秒)。None 表示使用默认值 5.0
|
||||
print_publish: 是否打印发布日志。None 表示使用节点默认配置
|
||||
qos: QoS深度配置。None 表示使用默认值 10
|
||||
|
||||
Example:
|
||||
class MyDriver:
|
||||
# 方式1: 装饰 get_{attr_name} 方法
|
||||
@topic_config(period=1.0, print_publish=False, qos=5)
|
||||
def get_temperature(self):
|
||||
return self._temperature
|
||||
|
||||
# 方式2: 与 @property 连用(topic_config 放在下面)
|
||||
@property
|
||||
@topic_config(period=0.1)
|
||||
def position(self):
|
||||
return self._position
|
||||
|
||||
Note:
|
||||
与 @property 连用时,@topic_config 必须放在 @property 下面,
|
||||
这样装饰器执行顺序为:先 topic_config 添加配置,再 property 包装。
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 在函数上附加配置属性 (type: ignore 用于动态属性)
|
||||
wrapper._topic_period = period # type: ignore[attr-defined]
|
||||
wrapper._topic_print_publish = print_publish # type: ignore[attr-defined]
|
||||
wrapper._topic_qos = qos # type: ignore[attr-defined]
|
||||
wrapper._has_topic_config = True # type: ignore[attr-defined]
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_topic_config(func) -> dict:
|
||||
"""
|
||||
获取函数上的topic配置
|
||||
|
||||
Args:
|
||||
func: 被装饰的函数
|
||||
|
||||
Returns:
|
||||
包含 period, print_publish, qos 的配置字典
|
||||
"""
|
||||
if hasattr(func, "_has_topic_config") and getattr(func, "_has_topic_config", False):
|
||||
return {
|
||||
"period": getattr(func, "_topic_period", None),
|
||||
"print_publish": getattr(func, "_topic_print_publish", None),
|
||||
"qos": getattr(func, "_topic_qos", None),
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
def subscribe(
|
||||
topic: str,
|
||||
msg_type: Optional[type] = None,
|
||||
@@ -104,24 +36,6 @@ def subscribe(
|
||||
- {namespace}: 完整命名空间 (如 "/devices/pump_1")
|
||||
msg_type: ROS 消息类型。如果为 None,需要在回调函数的类型注解中指定
|
||||
qos: QoS 深度配置,默认为 10
|
||||
|
||||
Example:
|
||||
from std_msgs.msg import String, Float64
|
||||
|
||||
class MyDriver:
|
||||
@subscribe(topic="/devices/{device_id}/set_speed", msg_type=Float64)
|
||||
def on_speed_update(self, msg: Float64):
|
||||
self._speed = msg.data
|
||||
print(f"Speed updated to: {self._speed}")
|
||||
|
||||
@subscribe(topic="{namespace}/command")
|
||||
def on_command(self, msg: String):
|
||||
# msg_type 可从类型注解推断
|
||||
self.execute_command(msg.data)
|
||||
|
||||
Note:
|
||||
- 回调方法的第一个参数是 self,第二个参数是收到的 ROS 消息
|
||||
- topic 中的占位符会在创建订阅时被实际值替换
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@@ -129,7 +43,6 @@ def subscribe(
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 在函数上附加订阅配置
|
||||
wrapper._subscribe_topic = topic # type: ignore[attr-defined]
|
||||
wrapper._subscribe_msg_type = msg_type # type: ignore[attr-defined]
|
||||
wrapper._subscribe_qos = qos # type: ignore[attr-defined]
|
||||
@@ -141,15 +54,7 @@ def subscribe(
|
||||
|
||||
|
||||
def get_subscribe_config(func) -> dict:
|
||||
"""
|
||||
获取函数上的订阅配置
|
||||
|
||||
Args:
|
||||
func: 被装饰的函数
|
||||
|
||||
Returns:
|
||||
包含 topic, msg_type, qos 的配置字典
|
||||
"""
|
||||
"""获取函数上的订阅配置 (topic, msg_type, qos)"""
|
||||
if hasattr(func, "_has_subscribe") and getattr(func, "_has_subscribe", False):
|
||||
return {
|
||||
"topic": getattr(func, "_subscribe_topic", None),
|
||||
@@ -163,9 +68,6 @@ def get_all_subscriptions(instance) -> list:
|
||||
"""
|
||||
扫描实例的所有方法,获取带有 @subscribe 装饰器的方法及其配置
|
||||
|
||||
Args:
|
||||
instance: 要扫描的实例
|
||||
|
||||
Returns:
|
||||
包含 (method_name, method, config) 元组的列表
|
||||
"""
|
||||
@@ -184,92 +86,14 @@ def get_all_subscriptions(instance) -> list:
|
||||
return subscriptions
|
||||
|
||||
|
||||
def always_free(func: F) -> F:
|
||||
"""
|
||||
标记动作为永久闲置(不受busy队列限制)的装饰器
|
||||
|
||||
被此装饰器标记的 action 方法,在执行时不会受到设备级别的排队限制,
|
||||
任何时候请求都可以立即执行。适用于查询类、状态读取类等轻量级操作。
|
||||
|
||||
Example:
|
||||
class MyDriver:
|
||||
@always_free
|
||||
def query_status(self, param: str):
|
||||
# 这个动作可以随时执行,不需要排队
|
||||
return self._status
|
||||
|
||||
def transfer(self, volume: float):
|
||||
# 这个动作会按正常排队逻辑执行
|
||||
pass
|
||||
|
||||
Note:
|
||||
- 可以与其他装饰器组合使用,@always_free 应放在最外层
|
||||
- 仅影响 WebSocket 调度层的 busy/free 判断,不影响 ROS2 层
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper._is_always_free = True # type: ignore[attr-defined]
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
|
||||
def is_always_free(func) -> bool:
|
||||
"""
|
||||
检查函数是否被标记为永久闲置
|
||||
|
||||
Args:
|
||||
func: 被检查的函数
|
||||
|
||||
Returns:
|
||||
如果函数被 @always_free 装饰则返回 True,否则返回 False
|
||||
"""
|
||||
return getattr(func, "_is_always_free", False)
|
||||
|
||||
|
||||
def not_action(func: F) -> F:
|
||||
"""
|
||||
标记方法为非动作的装饰器
|
||||
|
||||
用于装饰 driver 类中的方法,使其在 complete_registry 时不被识别为动作。
|
||||
适用于辅助方法、内部工具方法等不应暴露为设备动作的公共方法。
|
||||
|
||||
Example:
|
||||
class MyDriver:
|
||||
@not_action
|
||||
def helper_method(self):
|
||||
# 这个方法不会被注册为动作
|
||||
pass
|
||||
|
||||
def actual_action(self, param: str):
|
||||
# 这个方法会被注册为动作
|
||||
self.helper_method()
|
||||
|
||||
Note:
|
||||
- 可以与其他装饰器组合使用,@not_action 应放在最外层
|
||||
- 仅影响 complete_registry 的动作识别,不影响方法的正常调用
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# 在函数上附加标记
|
||||
wrapper._is_not_action = True # type: ignore[attr-defined]
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
|
||||
def is_not_action(func) -> bool:
|
||||
"""
|
||||
检查函数是否被标记为非动作
|
||||
|
||||
Args:
|
||||
func: 被检查的函数
|
||||
|
||||
Returns:
|
||||
如果函数被 @not_action 装饰则返回 True,否则返回 False
|
||||
"""
|
||||
return getattr(func, "_is_not_action", False)
|
||||
# ---------------------------------------------------------------------------
|
||||
# 向后兼容重导出 -- 已迁移到 unilabos.registry.decorators
|
||||
# ---------------------------------------------------------------------------
|
||||
from unilabos.registry.decorators import ( # noqa: E402, F401
|
||||
topic_config,
|
||||
get_topic_config,
|
||||
always_free,
|
||||
is_always_free,
|
||||
not_action,
|
||||
is_not_action,
|
||||
)
|
||||
|
||||
@@ -6,54 +6,199 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import locale
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from unilabos.utils.banner_print import print_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 底层安装工具
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _is_chinese_locale() -> bool:
|
||||
try:
|
||||
lang = locale.getdefaultlocale()[0]
|
||||
return bool(lang and ("zh" in lang.lower() or "chinese" in lang.lower()))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_USE_UV: Optional[bool] = None
|
||||
|
||||
|
||||
def _has_uv() -> bool:
|
||||
global _USE_UV
|
||||
if _USE_UV is None:
|
||||
_USE_UV = shutil.which("uv") is not None
|
||||
return _USE_UV
|
||||
|
||||
|
||||
def _install_packages(
|
||||
packages: List[str],
|
||||
upgrade: bool = False,
|
||||
label: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
安装/升级一组包。优先 uv pip install,回退 sys pip。
|
||||
逐个安装,任意一个失败不影响后续包。
|
||||
|
||||
Returns:
|
||||
True if all succeeded, False otherwise.
|
||||
"""
|
||||
if not packages:
|
||||
return True
|
||||
|
||||
is_chinese = _is_chinese_locale()
|
||||
use_uv = _has_uv()
|
||||
failed: List[str] = []
|
||||
|
||||
for pkg in packages:
|
||||
action_word = "升级" if upgrade else "安装"
|
||||
if label:
|
||||
print_status(f"[{label}] 正在{action_word} {pkg}...", "info")
|
||||
else:
|
||||
print_status(f"正在{action_word} {pkg}...", "info")
|
||||
|
||||
if use_uv:
|
||||
cmd = ["uv", "pip", "install"]
|
||||
if upgrade:
|
||||
cmd.append("--upgrade")
|
||||
cmd.append(pkg)
|
||||
if is_chinese:
|
||||
cmd.extend(["--index-url", "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"])
|
||||
else:
|
||||
cmd = [sys.executable, "-m", "pip", "install"]
|
||||
if upgrade:
|
||||
cmd.append("--upgrade")
|
||||
cmd.append(pkg)
|
||||
if is_chinese:
|
||||
cmd.extend(["-i", "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
if result.returncode == 0:
|
||||
installer = "uv" if use_uv else "pip"
|
||||
print_status(f"✓ {pkg} {action_word}成功 (via {installer})", "success")
|
||||
else:
|
||||
stderr_short = result.stderr.strip().split("\n")[-1] if result.stderr else "unknown error"
|
||||
print_status(f"× {pkg} {action_word}失败: {stderr_short}", "error")
|
||||
failed.append(pkg)
|
||||
except subprocess.TimeoutExpired:
|
||||
print_status(f"× {pkg} {action_word}超时 (300s)", "error")
|
||||
failed.append(pkg)
|
||||
except Exception as e:
|
||||
print_status(f"× {pkg} {action_word}异常: {e}", "error")
|
||||
failed.append(pkg)
|
||||
|
||||
if failed:
|
||||
print_status(f"有 {len(failed)} 个包操作失败: {', '.join(failed)}", "error")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# requirements.txt 安装(可多次调用)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def install_requirements_txt(req_path: str | Path, label: str = "") -> bool:
|
||||
"""
|
||||
读取一个 requirements.txt 文件,检查缺失的包并安装。
|
||||
|
||||
Args:
|
||||
req_path: requirements.txt 文件路径
|
||||
label: 日志前缀标签(如 "device_package_sim")
|
||||
|
||||
Returns:
|
||||
True if all ok, False if any install failed.
|
||||
"""
|
||||
req_path = Path(req_path)
|
||||
if not req_path.exists():
|
||||
return True
|
||||
|
||||
tag = label or req_path.parent.name
|
||||
print_status(f"[{tag}] 检查依赖: {req_path}", "info")
|
||||
|
||||
reqs: List[str] = []
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and not line.startswith("-"):
|
||||
reqs.append(line)
|
||||
|
||||
if not reqs:
|
||||
return True
|
||||
|
||||
missing: List[str] = []
|
||||
for req in reqs:
|
||||
pkg_import = req.split(">=")[0].split("==")[0].split("<")[0].split("[")[0].split(">")[0].strip()
|
||||
pkg_import = pkg_import.replace("-", "_")
|
||||
try:
|
||||
importlib.import_module(pkg_import)
|
||||
except ImportError:
|
||||
missing.append(req)
|
||||
|
||||
if not missing:
|
||||
print_status(f"[{tag}] ✓ 依赖检查通过 ({len(reqs)} 个包)", "success")
|
||||
return True
|
||||
|
||||
print_status(f"[{tag}] 缺失 {len(missing)} 个依赖: {', '.join(missing)}", "warning")
|
||||
return _install_packages(missing, label=tag)
|
||||
|
||||
|
||||
def check_device_package_requirements(devices_dirs: list[str]) -> bool:
|
||||
"""
|
||||
检查 --devices 指定的所有外部设备包目录中的 requirements.txt。
|
||||
对每个目录查找 requirements.txt(先在目录内找,再在父目录找)。
|
||||
"""
|
||||
if not devices_dirs:
|
||||
return True
|
||||
|
||||
all_ok = True
|
||||
for d in devices_dirs:
|
||||
d_path = Path(d).resolve()
|
||||
req_file = d_path / "requirements.txt"
|
||||
if not req_file.exists():
|
||||
req_file = d_path.parent / "requirements.txt"
|
||||
if not req_file.exists():
|
||||
continue
|
||||
if not install_requirements_txt(req_file, label=d_path.name):
|
||||
all_ok = False
|
||||
|
||||
return all_ok
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniLabOS 核心环境检查
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EnvironmentChecker:
|
||||
"""环境检查器"""
|
||||
|
||||
def __init__(self):
|
||||
# 定义必需的包及其安装名称的映射
|
||||
self.required_packages = {
|
||||
# 包导入名 : pip安装名
|
||||
# "pymodbus.framer.FramerType": "pymodbus==3.9.2",
|
||||
"websockets": "websockets",
|
||||
"msgcenterpy": "msgcenterpy",
|
||||
"orjson": "orjson",
|
||||
"opentrons_shared_data": "opentrons_shared_data",
|
||||
"typing_extensions": "typing_extensions",
|
||||
"crcmod": "crcmod-plus",
|
||||
}
|
||||
|
||||
# 特殊安装包(需要特殊处理的包)
|
||||
self.special_packages = {"pylabrobot": "git+https://github.com/Xuwznln/pylabrobot.git"}
|
||||
|
||||
# 包版本要求(包名: 最低版本)
|
||||
self.version_requirements = {
|
||||
"msgcenterpy": "0.1.5", # msgcenterpy 最低版本要求
|
||||
"msgcenterpy": "0.1.8",
|
||||
}
|
||||
|
||||
self.missing_packages = []
|
||||
self.failed_installs = []
|
||||
self.packages_need_upgrade = []
|
||||
|
||||
# 检测系统语言
|
||||
self.is_chinese = self._is_chinese_locale()
|
||||
|
||||
def _is_chinese_locale(self) -> bool:
|
||||
"""检测系统是否为中文环境"""
|
||||
try:
|
||||
lang = locale.getdefaultlocale()[0]
|
||||
if lang and ("zh" in lang.lower() or "chinese" in lang.lower()):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
self.missing_packages: List[tuple] = []
|
||||
self.failed_installs: List[tuple] = []
|
||||
self.packages_need_upgrade: List[tuple] = []
|
||||
|
||||
def check_package_installed(self, package_name: str) -> bool:
|
||||
"""检查包是否已安装"""
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
return True
|
||||
@@ -61,7 +206,6 @@ class EnvironmentChecker:
|
||||
return False
|
||||
|
||||
def get_package_version(self, package_name: str) -> str | None:
|
||||
"""获取已安装包的版本"""
|
||||
try:
|
||||
module = importlib.import_module(package_name)
|
||||
return getattr(module, "__version__", None)
|
||||
@@ -69,88 +213,32 @@ class EnvironmentChecker:
|
||||
return None
|
||||
|
||||
def compare_version(self, current: str, required: str) -> bool:
|
||||
"""
|
||||
比较版本号
|
||||
Returns:
|
||||
True: current >= required
|
||||
False: current < required
|
||||
"""
|
||||
try:
|
||||
current_parts = [int(x) for x in current.split(".")]
|
||||
required_parts = [int(x) for x in required.split(".")]
|
||||
|
||||
# 补齐长度
|
||||
max_len = max(len(current_parts), len(required_parts))
|
||||
current_parts.extend([0] * (max_len - len(current_parts)))
|
||||
required_parts.extend([0] * (max_len - len(required_parts)))
|
||||
|
||||
return current_parts >= required_parts
|
||||
except Exception:
|
||||
return True # 如果无法比较,假设版本满足要求
|
||||
|
||||
def install_package(self, package_name: str, pip_name: str, upgrade: bool = False) -> bool:
|
||||
"""安装包"""
|
||||
try:
|
||||
action = "升级" if upgrade else "安装"
|
||||
print_status(f"正在{action} {package_name} ({pip_name})...", "info")
|
||||
|
||||
# 构建安装命令
|
||||
cmd = [sys.executable, "-m", "pip", "install"]
|
||||
|
||||
# 如果是升级操作,添加 --upgrade 参数
|
||||
if upgrade:
|
||||
cmd.append("--upgrade")
|
||||
|
||||
cmd.append(pip_name)
|
||||
|
||||
# 如果是中文环境,使用清华镜像源
|
||||
if self.is_chinese:
|
||||
cmd.extend(["-i", "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"])
|
||||
|
||||
# 执行安装
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) # 5分钟超时
|
||||
|
||||
if result.returncode == 0:
|
||||
print_status(f"✓ {package_name} {action}成功", "success")
|
||||
return True
|
||||
else:
|
||||
print_status(f"× {package_name} {action}失败: {result.stderr}", "error")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print_status(f"× {package_name} {action}超时", "error")
|
||||
return False
|
||||
except Exception as e:
|
||||
print_status(f"× {package_name} {action}异常: {str(e)}", "error")
|
||||
return False
|
||||
|
||||
def upgrade_package(self, package_name: str, pip_name: str) -> bool:
|
||||
"""升级包"""
|
||||
return self.install_package(package_name, pip_name, upgrade=True)
|
||||
return True
|
||||
|
||||
def check_all_packages(self) -> bool:
|
||||
"""检查所有必需的包"""
|
||||
print_status("开始检查环境依赖...", "info")
|
||||
|
||||
# 检查常规包
|
||||
for import_name, pip_name in self.required_packages.items():
|
||||
if not self.check_package_installed(import_name):
|
||||
self.missing_packages.append((import_name, pip_name))
|
||||
else:
|
||||
# 检查版本要求
|
||||
if import_name in self.version_requirements:
|
||||
current_version = self.get_package_version(import_name)
|
||||
required_version = self.version_requirements[import_name]
|
||||
elif import_name in self.version_requirements:
|
||||
current_version = self.get_package_version(import_name)
|
||||
required_version = self.version_requirements[import_name]
|
||||
if current_version and not self.compare_version(current_version, required_version):
|
||||
print_status(
|
||||
f"{import_name} 版本过低 (当前: {current_version}, 需要: >={required_version})",
|
||||
"warning",
|
||||
)
|
||||
self.packages_need_upgrade.append((import_name, pip_name))
|
||||
|
||||
if current_version:
|
||||
if not self.compare_version(current_version, required_version):
|
||||
print_status(
|
||||
f"{import_name} 版本过低 (当前: {current_version}, 需要: >={required_version})",
|
||||
"warning",
|
||||
)
|
||||
self.packages_need_upgrade.append((import_name, pip_name))
|
||||
|
||||
# 检查特殊包
|
||||
for package_name, install_url in self.special_packages.items():
|
||||
if not self.check_package_installed(package_name):
|
||||
self.missing_packages.append((package_name, install_url))
|
||||
@@ -169,7 +257,6 @@ class EnvironmentChecker:
|
||||
return False
|
||||
|
||||
def install_missing_packages(self, auto_install: bool = True) -> bool:
|
||||
"""安装缺失的包"""
|
||||
if not self.missing_packages and not self.packages_need_upgrade:
|
||||
return True
|
||||
|
||||
@@ -177,62 +264,36 @@ class EnvironmentChecker:
|
||||
if self.missing_packages:
|
||||
print_status("缺失以下包:", "warning")
|
||||
for import_name, pip_name in self.missing_packages:
|
||||
print_status(f" - {import_name} (pip install {pip_name})", "warning")
|
||||
print_status(f" - {import_name} ({pip_name})", "warning")
|
||||
if self.packages_need_upgrade:
|
||||
print_status("需要升级以下包:", "warning")
|
||||
for import_name, pip_name in self.packages_need_upgrade:
|
||||
print_status(f" - {import_name} (pip install --upgrade {pip_name})", "warning")
|
||||
print_status(f" - {import_name} ({pip_name})", "warning")
|
||||
return False
|
||||
|
||||
# 安装缺失的包
|
||||
if self.missing_packages:
|
||||
print_status(f"开始自动安装 {len(self.missing_packages)} 个缺失的包...", "info")
|
||||
pkgs = [pip_name for _, pip_name in self.missing_packages]
|
||||
if not _install_packages(pkgs, label="unilabos"):
|
||||
self.failed_installs.extend(self.missing_packages)
|
||||
|
||||
success_count = 0
|
||||
for import_name, pip_name in self.missing_packages:
|
||||
if self.install_package(import_name, pip_name):
|
||||
success_count += 1
|
||||
else:
|
||||
self.failed_installs.append((import_name, pip_name))
|
||||
|
||||
print_status(f"✓ 成功安装 {success_count}/{len(self.missing_packages)} 个包", "success")
|
||||
|
||||
# 升级需要更新的包
|
||||
if self.packages_need_upgrade:
|
||||
print_status(f"开始自动升级 {len(self.packages_need_upgrade)} 个包...", "info")
|
||||
pkgs = [pip_name for _, pip_name in self.packages_need_upgrade]
|
||||
if not _install_packages(pkgs, upgrade=True, label="unilabos"):
|
||||
self.failed_installs.extend(self.packages_need_upgrade)
|
||||
|
||||
upgrade_success_count = 0
|
||||
for import_name, pip_name in self.packages_need_upgrade:
|
||||
if self.upgrade_package(import_name, pip_name):
|
||||
upgrade_success_count += 1
|
||||
else:
|
||||
self.failed_installs.append((import_name, pip_name))
|
||||
|
||||
print_status(f"✓ 成功升级 {upgrade_success_count}/{len(self.packages_need_upgrade)} 个包", "success")
|
||||
|
||||
if self.failed_installs:
|
||||
print_status(f"有 {len(self.failed_installs)} 个包操作失败:", "error")
|
||||
for import_name, pip_name in self.failed_installs:
|
||||
print_status(f" - {import_name} ({pip_name})", "error")
|
||||
return False
|
||||
|
||||
return True
|
||||
return not self.failed_installs
|
||||
|
||||
def verify_installation(self) -> bool:
|
||||
"""验证安装结果"""
|
||||
if not self.missing_packages and not self.packages_need_upgrade:
|
||||
return True
|
||||
|
||||
print_status("验证安装结果...", "info")
|
||||
|
||||
failed_verification = []
|
||||
|
||||
# 验证新安装的包
|
||||
for import_name, pip_name in self.missing_packages:
|
||||
if not self.check_package_installed(import_name):
|
||||
failed_verification.append((import_name, pip_name))
|
||||
|
||||
# 验证升级的包
|
||||
for import_name, pip_name in self.packages_need_upgrade:
|
||||
if not self.check_package_installed(import_name):
|
||||
failed_verification.append((import_name, pip_name))
|
||||
@@ -269,17 +330,14 @@ def check_environment(auto_install: bool = True, show_details: bool = True) -> b
|
||||
"""
|
||||
checker = EnvironmentChecker()
|
||||
|
||||
# 检查包
|
||||
if checker.check_all_packages():
|
||||
return True
|
||||
|
||||
# 安装缺失的包
|
||||
if not checker.install_missing_packages(auto_install):
|
||||
if show_details:
|
||||
print_status("请手动安装缺失的包后重新启动程序", "error")
|
||||
return False
|
||||
|
||||
# 验证安装
|
||||
if not checker.verify_installation():
|
||||
if show_details:
|
||||
print_status("安装验证失败,请检查网络连接或手动安装", "error")
|
||||
@@ -289,14 +347,12 @@ def check_environment(auto_install: bool = True, show_details: bool = True) -> b
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 命令行参数解析
|
||||
parser = argparse.ArgumentParser(description="UniLabOS 环境依赖检查工具")
|
||||
parser.add_argument("--no-auto-install", action="store_true", help="仅检查环境,不自动安装缺失的包")
|
||||
parser.add_argument("--silent", action="store_true", help="静默模式,不显示详细信息")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 执行环境检查
|
||||
auto_install = not args.no_auto_install
|
||||
show_details = not args.silent
|
||||
|
||||
|
||||
@@ -21,15 +21,11 @@ __all__ = [
|
||||
"get_class",
|
||||
"get_module",
|
||||
"init_from_list",
|
||||
"get_class_info_static",
|
||||
"get_registry_class_info",
|
||||
"get_enhanced_class_info",
|
||||
]
|
||||
|
||||
from ast import Constant
|
||||
|
||||
from unilabos.resources.resource_tracker import PARAM_SAMPLE_UUIDS
|
||||
from unilabos.utils import logger
|
||||
from unilabos.utils.decorator import is_not_action, is_always_free
|
||||
|
||||
|
||||
class ImportManager:
|
||||
@@ -45,6 +41,7 @@ class ImportManager:
|
||||
self._modules: Dict[str, Any] = {}
|
||||
self._classes: Dict[str, Type] = {}
|
||||
self._functions: Dict[str, Callable] = {}
|
||||
self._search_miss: set = set()
|
||||
|
||||
if module_list:
|
||||
for module_path in module_list:
|
||||
@@ -159,187 +156,113 @@ class ImportManager:
|
||||
Returns:
|
||||
找到的类对象,如果未找到则返回None
|
||||
"""
|
||||
# 如果cls_name是builtins中的关键字,则返回对应类
|
||||
if class_name in builtins.__dict__:
|
||||
return builtins.__dict__[class_name]
|
||||
# 首先在已索引的类中查找
|
||||
if class_name in self._classes:
|
||||
return self._classes[class_name]
|
||||
|
||||
cache_key = class_name.lower() if search_lower else class_name
|
||||
if cache_key in self._search_miss:
|
||||
return None
|
||||
|
||||
if search_lower:
|
||||
classes = {name.lower(): obj for name, obj in self._classes.items()}
|
||||
if class_name in classes:
|
||||
return classes[class_name]
|
||||
|
||||
# 遍历所有已加载的模块进行搜索
|
||||
for module_path, module in self._modules.items():
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and (
|
||||
(name.lower() == class_name.lower()) if search_lower else (name == class_name)
|
||||
):
|
||||
# 将找到的类添加到索引中
|
||||
self._classes[name] = obj
|
||||
self._classes[f"{module_path}:{name}"] = obj
|
||||
return obj
|
||||
|
||||
self._search_miss.add(cache_key)
|
||||
return None
|
||||
|
||||
def get_enhanced_class_info(self, module_path: str, use_dynamic: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
获取增强的类信息,支持动态导入和静态分析
|
||||
def get_enhanced_class_info(self, module_path: str, **_kwargs) -> Dict[str, Any]:
|
||||
"""通过 AST 分析获取类的增强信息。
|
||||
|
||||
复用 ``ast_registry_scanner`` 的 ``_collect_imports`` / ``_extract_class_body``,
|
||||
与 AST 扫描注册表完全一致。
|
||||
|
||||
Args:
|
||||
module_path: 模块路径,格式为 "module.path" 或 "module.path:ClassName"
|
||||
use_dynamic: 是否优先使用动态导入
|
||||
module_path: 格式 ``"module.path:ClassName"``
|
||||
|
||||
Returns:
|
||||
包含详细类信息的字典
|
||||
``{"module_path", "ast_analysis_success", "import_map",
|
||||
"init_params", "status_methods", "action_methods"}``
|
||||
"""
|
||||
result = {
|
||||
from unilabos.registry.ast_registry_scanner import (
|
||||
_collect_imports,
|
||||
_extract_class_body,
|
||||
_filepath_to_module,
|
||||
)
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"module_path": module_path,
|
||||
"dynamic_import_success": False,
|
||||
"static_analysis_success": False,
|
||||
"init_params": {},
|
||||
"status_methods": {}, # get_ 开头和 @property 方法
|
||||
"action_methods": {}, # set_ 开头和其他非_开头方法
|
||||
}
|
||||
|
||||
# 尝试动态导入
|
||||
dynamic_info = None
|
||||
static_info = None
|
||||
if use_dynamic:
|
||||
try:
|
||||
dynamic_info = self._get_dynamic_class_info(module_path)
|
||||
result["dynamic_import_success"] = True
|
||||
logger.debug(f"[ImportManager] 动态导入类 {module_path} 成功")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[UniLab Registry] 在补充注册表时,动态导入类 "
|
||||
f"{module_path} 失败(将使用静态分析,"
|
||||
f"建议修复导入错误,以实现更好的注册表识别效果!): {e}"
|
||||
)
|
||||
use_dynamic = False
|
||||
if not use_dynamic:
|
||||
# 尝试静态分析
|
||||
try:
|
||||
static_info = self._get_static_class_info(module_path)
|
||||
result["static_analysis_success"] = True
|
||||
logger.debug(f"[ImportManager] 静态分析类 {module_path} 成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ImportManager] 静态分析类 {module_path} 失败: {e}")
|
||||
|
||||
# 合并信息(优先使用动态导入的信息)
|
||||
if dynamic_info:
|
||||
result.update(dynamic_info)
|
||||
elif static_info:
|
||||
result.update(static_info)
|
||||
|
||||
return result
|
||||
|
||||
def _get_dynamic_class_info(self, class_path: str) -> Dict[str, Any]:
|
||||
"""使用inspect模块动态获取类信息"""
|
||||
cls = get_class(class_path)
|
||||
class_name = cls.__name__
|
||||
|
||||
result = {
|
||||
"class_name": class_name,
|
||||
"init_params": self._analyze_method_signature(cls.__init__)["args"],
|
||||
"ast_analysis_success": False,
|
||||
"import_map": {},
|
||||
"init_params": [],
|
||||
"status_methods": {},
|
||||
"action_methods": {},
|
||||
}
|
||||
# 分析类的所有成员
|
||||
for name, method in cls.__dict__.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# 检查是否是property
|
||||
if isinstance(method, property):
|
||||
# @property 装饰的方法
|
||||
# noinspection PyTypeChecker
|
||||
return_type = self._get_return_type_from_method(method.fget) if method.fget else "Any"
|
||||
prop_info = {
|
||||
"name": name,
|
||||
"return_type": return_type,
|
||||
}
|
||||
result["status_methods"][name] = prop_info
|
||||
|
||||
# 检查是否有对应的setter
|
||||
if method.fset:
|
||||
setter_info = self._analyze_method_signature(method.fset)
|
||||
result["action_methods"][name] = setter_info
|
||||
|
||||
elif inspect.ismethod(method) or inspect.isfunction(method):
|
||||
if name.startswith("get_"):
|
||||
actual_name = name[4:] # 去掉get_前缀
|
||||
if actual_name in result["status_methods"]:
|
||||
continue
|
||||
# get_ 开头的方法归类为status
|
||||
method_info = self._analyze_method_signature(method)
|
||||
result["status_methods"][actual_name] = method_info
|
||||
elif not name.startswith("_"):
|
||||
# 其他非_开头的方法归类为action
|
||||
method_info = self._analyze_method_signature(method)
|
||||
# 检查是否被 @always_free 装饰器标记
|
||||
if is_always_free(method):
|
||||
method_info["always_free"] = True
|
||||
result["action_methods"][name] = method_info
|
||||
|
||||
return result
|
||||
|
||||
def _get_static_class_info(self, module_path: str) -> Dict[str, Any]:
|
||||
"""使用AST静态分析获取类信息"""
|
||||
module_name, class_name = module_path.rsplit(":", 1)
|
||||
# 将模块路径转换为文件路径
|
||||
file_path = self._module_path_to_file_path(module_name)
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"找不到模块文件: {module_name} -> {file_path}")
|
||||
logger.warning(f"[ImportManager] 找不到模块文件: {module_name} -> {file_path}")
|
||||
return result
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
tree = ast.parse(f.read(), filename=file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ImportManager] 解析文件 {file_path} 失败: {e}")
|
||||
return result
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
# 推导 module dotted path → 构建 import_map
|
||||
python_path = Path(file_path)
|
||||
for sp in sorted(sys.path, key=len, reverse=True):
|
||||
try:
|
||||
Path(file_path).relative_to(sp)
|
||||
python_path = Path(sp)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
module_dotted = _filepath_to_module(Path(file_path), python_path)
|
||||
import_map = _collect_imports(tree, module_dotted)
|
||||
result["import_map"] = import_map
|
||||
|
||||
# 查找目标类
|
||||
# 定位目标类 AST 节点
|
||||
target_class = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
if node.name == class_name:
|
||||
target_class = node
|
||||
break
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
||||
target_class = node
|
||||
break
|
||||
|
||||
if target_class is None:
|
||||
raise AttributeError(f"在文件 {file_path} 中找不到类 {class_name}")
|
||||
logger.warning(f"[ImportManager] 在文件 {file_path} 中找不到类 {class_name}")
|
||||
return result
|
||||
|
||||
result = {
|
||||
"class_name": class_name,
|
||||
"init_params": {},
|
||||
"status_methods": {},
|
||||
"action_methods": {},
|
||||
body = _extract_class_body(target_class, import_map)
|
||||
|
||||
# 映射到统一字段名(与 registry.py complete_registry 消费端一致)
|
||||
result["init_params"] = body.get("init_params", [])
|
||||
result["status_methods"] = body.get("status_properties", {})
|
||||
result["action_methods"] = {
|
||||
k: {
|
||||
"args": v.get("params", []),
|
||||
"return_type": v.get("return_type", ""),
|
||||
"is_async": v.get("is_async", False),
|
||||
"always_free": v.get("always_free", False),
|
||||
"docstring": v.get("docstring"),
|
||||
}
|
||||
for k, v in body.get("auto_methods", {}).items()
|
||||
}
|
||||
|
||||
# 分析类的方法
|
||||
for node in target_class.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
method_info = self._analyze_method_node(node)
|
||||
method_name = node.name
|
||||
if method_name == "__init__":
|
||||
result["init_params"] = method_info["args"]
|
||||
elif method_name.startswith("_"):
|
||||
continue
|
||||
elif self._is_property_method(node):
|
||||
# @property 装饰的方法
|
||||
result["status_methods"][method_name] = method_info
|
||||
elif method_name.startswith("get_"):
|
||||
# get_ 开头的方法归类为status
|
||||
actual_name = method_name[4:] # 去掉get_前缀
|
||||
if actual_name not in result["status_methods"]:
|
||||
result["status_methods"][actual_name] = method_info
|
||||
else:
|
||||
# 其他非_开头的方法归类为action
|
||||
# 检查是否被 @always_free 装饰器标记
|
||||
if self._is_always_free_method(node):
|
||||
method_info["always_free"] = True
|
||||
result["action_methods"][method_name] = method_info
|
||||
result["ast_analysis_success"] = True
|
||||
return result
|
||||
|
||||
def _analyze_method_signature(self, method, skip_unilabos_params: bool = True) -> Dict[str, Any]:
|
||||
@@ -395,23 +318,26 @@ class ImportManager:
|
||||
"name": method.__name__,
|
||||
"args": args,
|
||||
"return_type": self._get_type_string(signature.return_annotation),
|
||||
"return_annotation": signature.return_annotation, # 保留原始类型注解,用于TypedDict等特殊处理
|
||||
"is_async": inspect.iscoroutinefunction(method),
|
||||
}
|
||||
|
||||
def _get_return_type_from_method(self, method) -> str:
|
||||
def _get_return_type_from_method(self, method) -> Union[str, Tuple[str, Any]]:
|
||||
"""从方法中获取返回类型"""
|
||||
signature = inspect.signature(method)
|
||||
return self._get_type_string(signature.return_annotation)
|
||||
|
||||
def _get_type_string(self, annotation) -> Union[str, Tuple[str, Any]]:
|
||||
"""将类型注解转换为Class Library中可搜索的类名"""
|
||||
"""将类型注解转换为类型字符串。
|
||||
|
||||
非内建类返回 ``module:ClassName`` 全路径(如
|
||||
``"unilabos.registry.placeholder_type:ResourceSlot"``),
|
||||
避免短名冲突;内建类型直接返回短名(如 ``"str"``、``"int"``)。
|
||||
"""
|
||||
if annotation == inspect.Parameter.empty:
|
||||
return "Any" # 如果没有注解,返回Any
|
||||
return "Any"
|
||||
if annotation is None:
|
||||
return "None" # 明确的None类型
|
||||
return "None"
|
||||
if hasattr(annotation, "__origin__"):
|
||||
# 处理typing模块的类型
|
||||
origin = annotation.__origin__
|
||||
if origin in (list, set, tuple):
|
||||
if hasattr(annotation, "__args__") and annotation.__args__:
|
||||
@@ -426,126 +352,26 @@ class ImportManager:
|
||||
return "dict"
|
||||
elif origin is Optional:
|
||||
return "Unknown"
|
||||
return f"Unknown"
|
||||
return "Unknown"
|
||||
annotation_str = str(annotation)
|
||||
# 处理typing模块的复杂类型
|
||||
if "typing." in annotation_str:
|
||||
# 简化typing类型显示
|
||||
return (
|
||||
annotation_str.replace("typing.", "")
|
||||
if getattr(annotation, "_name", None) is None
|
||||
else annotation._name.lower()
|
||||
)
|
||||
# 如果是类型对象
|
||||
if hasattr(annotation, "__name__"):
|
||||
# 如果是内置类型
|
||||
if annotation.__module__ == "builtins":
|
||||
return annotation.__name__
|
||||
else:
|
||||
# 如果是自定义类,返回完整路径
|
||||
return f"{annotation.__module__}:{annotation.__name__}"
|
||||
# 如果是typing模块的类型
|
||||
module = getattr(annotation, "__module__", None)
|
||||
if module and module != "builtins":
|
||||
return f"{module}:{annotation.__name__}"
|
||||
return annotation.__name__
|
||||
elif hasattr(annotation, "_name"):
|
||||
return annotation._name
|
||||
# 如果是字符串形式的类型注解
|
||||
elif isinstance(annotation, str):
|
||||
return annotation
|
||||
else:
|
||||
return annotation_str
|
||||
|
||||
def _is_property_method(self, node: ast.FunctionDef) -> bool:
|
||||
"""检查是否是@property装饰的方法"""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Name) and decorator.id == "property":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_setter_method(self, node: ast.FunctionDef) -> bool:
|
||||
"""检查是否是@xxx.setter装饰的方法"""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Attribute) and decorator.attr == "setter":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_not_action_method(self, node: ast.FunctionDef) -> bool:
|
||||
"""检查是否是@not_action装饰的方法"""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Name) and decorator.id == "not_action":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_always_free_method(self, node: ast.FunctionDef) -> bool:
|
||||
"""检查是否是@always_free装饰的方法"""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Name) and decorator.id == "always_free":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_property_name_from_setter(self, node: ast.FunctionDef) -> str:
|
||||
"""从setter装饰器中获取属性名"""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Attribute) and decorator.attr == "setter":
|
||||
if isinstance(decorator.value, ast.Name):
|
||||
return decorator.value.id
|
||||
return node.name
|
||||
|
||||
def get_class_info_static(self, module_class_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
静态分析获取类的方法信息,不需要实际导入模块
|
||||
|
||||
Args:
|
||||
module_class_path: 格式为 "module.path:ClassName" 的字符串
|
||||
|
||||
Returns:
|
||||
包含类方法信息的字典
|
||||
"""
|
||||
try:
|
||||
if ":" not in module_class_path:
|
||||
raise ValueError("module_class_path必须是 'module.path:ClassName' 格式")
|
||||
|
||||
module_path, class_name = module_class_path.rsplit(":", 1)
|
||||
|
||||
# 将模块路径转换为文件路径
|
||||
file_path = self._module_path_to_file_path(module_path)
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
logger.warning(f"找不到模块文件: {module_path} -> {file_path}")
|
||||
return {}
|
||||
|
||||
# 解析源码
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# 查找目标类
|
||||
class_node = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
||||
class_node = node
|
||||
break
|
||||
|
||||
if not class_node:
|
||||
logger.warning(f"在模块 {module_path} 中找不到类 {class_name}")
|
||||
return {}
|
||||
|
||||
# 分析类的方法
|
||||
methods_info = {}
|
||||
for node in class_node.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
method_info = self._analyze_method_node(node)
|
||||
methods_info[node.name] = method_info
|
||||
|
||||
return {
|
||||
"class_name": class_name,
|
||||
"module_path": module_path,
|
||||
"file_path": file_path,
|
||||
"methods": methods_info,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"静态分析类 {module_class_path} 时出错: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _module_path_to_file_path(self, module_path: str) -> Optional[str]:
|
||||
for path in sys.path:
|
||||
potential_path = Path(path) / module_path.replace(".", "/")
|
||||
@@ -560,222 +386,6 @@ class ImportManager:
|
||||
|
||||
return None
|
||||
|
||||
def _analyze_method_node(self, node: ast.FunctionDef) -> Dict[str, Any]:
|
||||
"""分析方法节点,提取参数和返回类型信息"""
|
||||
method_info = {
|
||||
"name": node.name,
|
||||
"args": [],
|
||||
"return_type": None,
|
||||
"is_async": isinstance(node, ast.AsyncFunctionDef),
|
||||
}
|
||||
# 获取默认值列表
|
||||
defaults = node.args.defaults
|
||||
num_defaults = len(defaults)
|
||||
|
||||
# 计算必需参数数量
|
||||
total_args = len(node.args.args)
|
||||
num_required = total_args - num_defaults
|
||||
|
||||
# 提取参数信息
|
||||
for i, arg in enumerate(node.args.args):
|
||||
if arg.arg == "self":
|
||||
continue
|
||||
# 跳过 sample_uuids 参数(由系统自动注入)
|
||||
if arg.arg == PARAM_SAMPLE_UUIDS:
|
||||
continue
|
||||
arg_info = {
|
||||
"name": arg.arg,
|
||||
"type": None,
|
||||
"default": None,
|
||||
"required": i < num_required,
|
||||
}
|
||||
|
||||
# 提取类型注解
|
||||
if arg.annotation:
|
||||
arg_info["type"] = ast.unparse(arg.annotation) if hasattr(ast, "unparse") else str(arg.annotation)
|
||||
|
||||
# 提取默认值并推断类型
|
||||
if i >= num_required:
|
||||
default_index = i - num_required
|
||||
if default_index < len(defaults):
|
||||
default_value: Constant = defaults[default_index] # type: ignore
|
||||
assert isinstance(default_value, Constant), "暂不支持对非常量类型进行推断,可反馈开源仓库"
|
||||
arg_info["default"] = default_value.value
|
||||
# 如果没有类型注解,尝试从默认值推断类型
|
||||
if not arg_info["type"]:
|
||||
arg_info["type"] = self._get_type_string(type(arg_info["default"]))
|
||||
method_info["args"].append(arg_info)
|
||||
|
||||
# 提取返回类型
|
||||
if node.returns:
|
||||
method_info["return_type"] = ast.unparse(node.returns) if hasattr(ast, "unparse") else str(node.returns)
|
||||
|
||||
return method_info
|
||||
|
||||
def _infer_type_from_default(self, node: ast.AST) -> Optional[str]:
|
||||
"""从默认值推断参数类型"""
|
||||
if isinstance(node, ast.Constant):
|
||||
value = node.value
|
||||
if isinstance(value, bool):
|
||||
return "bool"
|
||||
elif isinstance(value, int):
|
||||
return "int"
|
||||
elif isinstance(value, float):
|
||||
return "float"
|
||||
elif isinstance(value, str):
|
||||
return "str"
|
||||
elif value is None:
|
||||
return "Optional[Any]"
|
||||
elif isinstance(node, ast.List):
|
||||
return "List"
|
||||
elif isinstance(node, ast.Dict):
|
||||
return "Dict"
|
||||
elif isinstance(node, ast.Tuple):
|
||||
return "Tuple"
|
||||
elif isinstance(node, ast.Set):
|
||||
return "Set"
|
||||
elif isinstance(node, ast.Name):
|
||||
# 常见的默认值模式
|
||||
if node.id in ["None"]:
|
||||
return "Optional[Any]"
|
||||
elif node.id in ["True", "False"]:
|
||||
return "bool"
|
||||
|
||||
return None
|
||||
|
||||
def _infer_types_from_docstring(self, method_info: Dict[str, Any]) -> None:
|
||||
"""从docstring中推断参数类型"""
|
||||
docstring = method_info.get("docstring", "")
|
||||
if not docstring:
|
||||
return
|
||||
|
||||
lines = docstring.split("\n")
|
||||
in_args_section = False
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# 检测Args或Arguments段落
|
||||
if line.lower().startswith(("args:", "arguments:")):
|
||||
in_args_section = True
|
||||
continue
|
||||
elif line.startswith(("returns:", "return:", "yields:", "raises:")):
|
||||
in_args_section = False
|
||||
continue
|
||||
elif not line or not in_args_section:
|
||||
continue
|
||||
|
||||
# 解析参数行,格式通常是: param_name (type): description 或 param_name: description
|
||||
if ":" in line:
|
||||
parts = line.split(":", 1)
|
||||
param_part = parts[0].strip()
|
||||
|
||||
# 提取参数名和类型
|
||||
param_name = None
|
||||
param_type = None
|
||||
|
||||
if "(" in param_part and ")" in param_part:
|
||||
# 格式: param_name (type)
|
||||
param_name = param_part.split("(")[0].strip()
|
||||
type_part = param_part.split("(")[1].split(")")[0].strip()
|
||||
param_type = type_part
|
||||
else:
|
||||
# 格式: param_name
|
||||
param_name = param_part
|
||||
|
||||
# 更新对应参数的类型信息
|
||||
if param_name:
|
||||
for arg_info in method_info["args"]:
|
||||
if arg_info["name"] == param_name and not arg_info["type"]:
|
||||
if param_type:
|
||||
arg_info["inferred_type"] = param_type
|
||||
elif not arg_info["inferred_type"]:
|
||||
# 从描述中推断类型
|
||||
description = parts[1].strip().lower()
|
||||
if any(word in description for word in ["path", "file", "directory", "filename"]):
|
||||
arg_info["inferred_type"] = "str"
|
||||
elif any(
|
||||
word in description for word in ["port", "number", "count", "size", "length"]
|
||||
):
|
||||
arg_info["inferred_type"] = "int"
|
||||
elif any(
|
||||
word in description for word in ["rate", "ratio", "percentage", "temperature"]
|
||||
):
|
||||
arg_info["inferred_type"] = "float"
|
||||
elif any(word in description for word in ["flag", "enable", "disable", "option"]):
|
||||
arg_info["inferred_type"] = "bool"
|
||||
|
||||
def get_registry_class_info(self, module_class_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取适用于注册表的类信息,包含完整的类型推断
|
||||
|
||||
Args:
|
||||
module_class_path: 格式为 "module.path:ClassName" 的字符串
|
||||
|
||||
Returns:
|
||||
适用于注册表的类信息字典
|
||||
"""
|
||||
class_info = self.get_class_info_static(module_class_path)
|
||||
if not class_info:
|
||||
return {}
|
||||
|
||||
registry_info = {
|
||||
"class_name": class_info["class_name"],
|
||||
"module_path": class_info["module_path"],
|
||||
"file_path": class_info["file_path"],
|
||||
"methods": {},
|
||||
"properties": [],
|
||||
"init_params": {},
|
||||
"action_methods": {},
|
||||
}
|
||||
|
||||
for method_name, method_info in class_info["methods"].items():
|
||||
# 分类处理不同类型的方法
|
||||
if method_info["is_property"]:
|
||||
registry_info["properties"].append(
|
||||
{
|
||||
"name": method_name,
|
||||
"return_type": method_info.get("return_type"),
|
||||
"docstring": method_info.get("docstring"),
|
||||
}
|
||||
)
|
||||
elif method_name == "__init__":
|
||||
# 处理初始化参数
|
||||
init_params = {}
|
||||
for arg in method_info["args"]:
|
||||
if arg["name"] != "self":
|
||||
param_info = {
|
||||
"name": arg["name"],
|
||||
"type": arg.get("type") or arg.get("inferred_type"),
|
||||
"required": arg.get("is_required", True),
|
||||
"default": arg.get("default"),
|
||||
}
|
||||
init_params[arg["name"]] = param_info
|
||||
registry_info["init_params"] = init_params
|
||||
elif not method_name.startswith("_"):
|
||||
# 处理公共方法(可能的action方法)
|
||||
action_info = {
|
||||
"name": method_name,
|
||||
"params": {},
|
||||
"return_type": method_info.get("return_type"),
|
||||
"docstring": method_info.get("docstring"),
|
||||
"num_required": method_info.get("num_required", 0) - 1, # 减去self
|
||||
"num_defaults": method_info.get("num_defaults", 0),
|
||||
}
|
||||
|
||||
for arg in method_info["args"]:
|
||||
if arg["name"] != "self":
|
||||
param_info = {
|
||||
"name": arg["name"],
|
||||
"type": arg.get("type") or arg.get("inferred_type"),
|
||||
"required": arg.get("is_required", True),
|
||||
"default": arg.get("default"),
|
||||
}
|
||||
action_info["params"][arg["name"]] = param_info
|
||||
|
||||
registry_info["action_methods"][method_name] = action_info
|
||||
|
||||
return registry_info
|
||||
|
||||
|
||||
# 全局实例,便于直接使用
|
||||
@@ -803,16 +413,6 @@ def init_from_list(module_list: List[str]) -> None:
|
||||
default_manager = ImportManager(module_list)
|
||||
|
||||
|
||||
def get_class_info_static(module_class_path: str) -> Dict[str, Any]:
|
||||
"""静态分析获取类信息的便捷函数"""
|
||||
return default_manager.get_class_info_static(module_class_path)
|
||||
|
||||
|
||||
def get_registry_class_info(module_class_path: str) -> Dict[str, Any]:
|
||||
"""获取适用于注册表的类信息的便捷函数"""
|
||||
return default_manager.get_registry_class_info(module_class_path)
|
||||
|
||||
|
||||
def get_enhanced_class_info(module_path: str, use_dynamic: bool = True) -> Dict[str, Any]:
|
||||
def get_enhanced_class_info(module_path: str, **kwargs) -> Dict[str, Any]:
|
||||
"""获取增强的类信息的便捷函数"""
|
||||
return default_manager.get_enhanced_class_info(module_path, use_dynamic)
|
||||
return default_manager.get_enhanced_class_info(module_path, **kwargs)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
networkx
|
||||
typing_extensions
|
||||
websockets
|
||||
msgcenterpy>=0.1.5
|
||||
msgcenterpy>=0.1.8
|
||||
orjson>=3.11
|
||||
opentrons_shared_data
|
||||
pint
|
||||
fastapi
|
||||
|
||||
@@ -1,4 +1,39 @@
|
||||
import json
|
||||
|
||||
from unilabos.utils.type_check import TypeEncoder, json_default
|
||||
|
||||
try:
|
||||
import orjson
|
||||
|
||||
def fast_dumps(obj, **kwargs) -> bytes:
|
||||
"""JSON 序列化为 bytes,优先使用 orjson。"""
|
||||
return orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS, default=json_default)
|
||||
|
||||
def fast_dumps_pretty(obj, **kwargs) -> bytes:
|
||||
"""JSON 序列化为 bytes(带缩进),优先使用 orjson。"""
|
||||
return orjson.dumps(
|
||||
obj,
|
||||
option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2,
|
||||
default=json_default,
|
||||
)
|
||||
|
||||
def normalize_json(info: dict) -> dict:
|
||||
"""经 JSON 序列化/反序列化一轮来清理非标准类型。"""
|
||||
return orjson.loads(orjson.dumps(info, default=json_default))
|
||||
|
||||
except ImportError:
|
||||
|
||||
def fast_dumps(obj, **kwargs) -> bytes: # type: ignore[misc]
|
||||
return json.dumps(obj, ensure_ascii=False, cls=TypeEncoder).encode("utf-8")
|
||||
|
||||
def fast_dumps_pretty(obj, **kwargs) -> bytes: # type: ignore[misc]
|
||||
return json.dumps(obj, indent=2, ensure_ascii=False, cls=TypeEncoder).encode("utf-8")
|
||||
|
||||
def normalize_json(info: dict) -> dict: # type: ignore[misc]
|
||||
return json.loads(json.dumps(info, ensure_ascii=False, cls=TypeEncoder))
|
||||
|
||||
|
||||
# 辅助函数:将UUID数组转换为字符串
|
||||
def uuid_to_str(uuid_array) -> str:
|
||||
"""将UUID字节数组转换为十六进制字符串"""
|
||||
return "".join(format(byte, "02x") for byte in uuid_array)
|
||||
return "".join(format(byte, "02x") for byte in uuid_array)
|
||||
|
||||
@@ -15,14 +15,21 @@ def get_type_class(type_hint):
|
||||
return final_type
|
||||
|
||||
|
||||
def json_default(obj):
|
||||
"""将 type 对象序列化为类名,其余 fallback 到 str()。"""
|
||||
if isinstance(obj, type):
|
||||
return str(obj)[8:-2]
|
||||
return str(obj)
|
||||
|
||||
|
||||
class TypeEncoder(json.JSONEncoder):
|
||||
"""自定义JSON编码器处理特殊类型"""
|
||||
|
||||
def default(self, obj):
|
||||
# 优先处理类型对象
|
||||
if isinstance(obj, type):
|
||||
return str(obj)[8:-2]
|
||||
return super().default(obj)
|
||||
try:
|
||||
return json_default(obj)
|
||||
except Exception:
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class NoAliasDumper(yaml.SafeDumper):
|
||||
@@ -43,13 +50,10 @@ class ResultInfoEncoder(json.JSONEncoder):
|
||||
"""专门用于处理任务执行结果信息的JSON编码器"""
|
||||
|
||||
def default(self, obj):
|
||||
# 优先处理类型对象
|
||||
if isinstance(obj, type):
|
||||
return str(obj)[8:-2]
|
||||
return json_default(obj)
|
||||
|
||||
# 对于无法序列化的对象,统一转换为字符串
|
||||
try:
|
||||
# 尝试调用 __dict__ 或者其他序列化方法
|
||||
if hasattr(obj, "__dict__"):
|
||||
return obj.__dict__
|
||||
elif hasattr(obj, "_asdict"): # namedtuple
|
||||
@@ -59,10 +63,8 @@ class ResultInfoEncoder(json.JSONEncoder):
|
||||
elif hasattr(obj, "dict"):
|
||||
return obj.dict()
|
||||
else:
|
||||
# 如果都不行,转换为字符串
|
||||
return str(obj)
|
||||
except Exception:
|
||||
# 如果转换失败,直接返回字符串表示
|
||||
return str(obj)
|
||||
|
||||
|
||||
@@ -78,11 +80,12 @@ def get_result_info_str(error: str, suc: bool, return_value=None) -> str:
|
||||
Returns:
|
||||
JSON字符串格式的结果信息
|
||||
"""
|
||||
samples = None
|
||||
if isinstance(return_value, dict):
|
||||
if "samples" in return_value:
|
||||
samples = return_value.pop("samples")
|
||||
result_info = {"error": error, "suc": suc, "return_value": return_value, "samples": samples}
|
||||
# 请在返回的字典中使用 unilabos_samples进行返回
|
||||
# samples = None
|
||||
# if isinstance(return_value, dict):
|
||||
# if "samples" in return_value and type(return_value["samples"]) in [list, tuple] and type(return_value["samples"][0]) == dict:
|
||||
# samples = return_value.pop("samples")
|
||||
result_info = {"error": error, "suc": suc, "return_value": return_value}
|
||||
|
||||
return json.dumps(result_info, ensure_ascii=False, cls=ResultInfoEncoder)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user