mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-03-25 03:29:14 +00:00
497 lines
20 KiB
Python
497 lines
20 KiB
Python
"""model_upload.py 单元测试(upload_device_model / download_model_from_oss / XOR 加解密)"""
|
||
|
||
import unittest
|
||
import tempfile
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from unittest.mock import patch, MagicMock
|
||
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||
|
||
from unilabos.app.model_upload import (
|
||
upload_device_model,
|
||
download_model_from_oss,
|
||
_MODEL_EXTENSIONS,
|
||
_MESH_ENCRYPT_EXTENSIONS,
|
||
_xor_transform,
|
||
)
|
||
|
||
|
||
class TestUploadDeviceModel(unittest.TestCase):
    """Tests for uploading local model files to OSS via ``upload_device_model``."""

    def setUp(self):
        # Self-cleaning temp dir: the previous mkdtemp() was never removed and
        # leaked one directory per test.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmp_dir = tmp.name
        self.mock_client = MagicMock()
        # Point the module's mesh base directory at the temp dir for every test.
        # This replaces the per-test decorator + inner `with patch(...)` pair,
        # where the decorator was always overridden and therefore dead code.
        base_patcher = patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir))
        base_patcher.start()
        self.addCleanup(base_patcher.stop)

    def _create_model_files(self, subdir: str, filenames: list[str]):
        """Create dummy model files under ``<tmp>/devices/<subdir>``.

        ``filenames`` may contain relative sub-paths (e.g. ``meshes/link1.stl``);
        intermediate directories are created as needed.

        Returns:
            The ``Path`` of the created model directory.
        """
        model_dir = Path(self.tmp_dir) / "devices" / subdir
        model_dir.mkdir(parents=True, exist_ok=True)
        for name in filenames:
            p = model_dir / name
            p.parent.mkdir(parents=True, exist_ok=True)
            p.write_text("dummy content")
        return model_dir

    def test_upload_success(self):
        """Normal upload flow returns the published model path."""
        self._create_model_files("arm_slider", ["macro_device.xacro", "meshes/link1.stl"])

        self.mock_client.get_model_upload_urls.return_value = {
            "files": [
                {"name": "macro_device.xacro", "upload_url": "https://oss.example.com/put1"},
                {"name": "meshes/link1.stl", "upload_url": "https://oss.example.com/put2"},
            ]
        }
        self.mock_client.publish_model.return_value = {
            "path": "https://oss.example.com/arm_slider/macro_device.xacro"
        }

        # Stub the actual HTTP PUT so no network traffic happens.
        with patch("unilabos.app.model_upload._put_upload"):
            result = upload_device_model(
                http_client=self.mock_client,
                template_uuid="test-uuid",
                mesh_name="arm_slider",
                model_type="device",
                version="1.0.0",
            )

        self.assertEqual(result, "https://oss.example.com/arm_slider/macro_device.xacro")
        self.mock_client.get_model_upload_urls.assert_called_once()
        self.mock_client.publish_model.assert_called_once()

    def test_upload_dir_not_exists(self):
        """Returns None when the local model directory does not exist."""
        result = upload_device_model(
            http_client=self.mock_client,
            template_uuid="test-uuid",
            mesh_name="nonexistent",
            model_type="device",
        )
        self.assertIsNone(result)

    def test_upload_no_valid_files(self):
        """Returns None when the directory contains no recognised model files."""
        model_dir = Path(self.tmp_dir) / "devices" / "empty_model"
        model_dir.mkdir(parents=True, exist_ok=True)
        (model_dir / "readme.txt").write_text("not a model")

        result = upload_device_model(
            http_client=self.mock_client,
            template_uuid="test-uuid",
            mesh_name="empty_model",
            model_type="device",
        )
        self.assertIsNone(result)

    def test_upload_urls_failure(self):
        """Returns None when fetching the upload URLs fails."""
        self._create_model_files("arm", ["device.xacro"])
        self.mock_client.get_model_upload_urls.return_value = None

        result = upload_device_model(
            http_client=self.mock_client,
            template_uuid="test-uuid",
            mesh_name="arm",
            model_type="device",
        )
        self.assertIsNone(result)
|
||
|
||
|
||
class TestDownloadModelFromOss(unittest.TestCase):
    """Tests for downloading model files from OSS into the local mesh directory."""

    def setUp(self):
        # Self-cleaning temp dir: the previous mkdtemp() was never removed.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.tmp_dir = tmp.name

    def test_skip_no_mesh_name(self):
        """Skips (returns False) when the mesh name is missing."""
        result = download_model_from_oss({"type": "device", "path": "https://x.com/a.xacro"})
        self.assertFalse(result)

    def test_skip_no_oss_path(self):
        """Skips when the OSS path is missing."""
        result = download_model_from_oss({"mesh": "arm", "type": "device"})
        self.assertFalse(result)

    def test_skip_local_path(self):
        """Skips when the path is not an https:// URL."""
        result = download_model_from_oss({
            "mesh": "arm",
            "type": "device",
            "path": "file:///local/model.xacro",
        })
        self.assertFalse(result)

    @patch("unilabos.app.model_upload._download_file")
    def test_already_exists(self, mock_download):
        """Skips the download when the model already exists locally."""
        device_dir = Path(self.tmp_dir) / "devices" / "arm"
        device_dir.mkdir(parents=True, exist_ok=True)
        (device_dir / "model.xacro").write_text("existing")

        result = download_model_from_oss(
            {"mesh": "arm", "type": "device", "path": "https://oss.example.com/model.xacro"},
            mesh_base_dir=Path(self.tmp_dir),
        )
        self.assertTrue(result)
        # Patching _download_file keeps the test offline even if the skip logic
        # regresses, and verifies that no download was actually attempted.
        mock_download.assert_not_called()

    @patch("unilabos.app.model_upload._download_file")
    def test_download_device(self, mock_download):
        """A device model is downloaded into the devices/ directory."""
        result = download_model_from_oss(
            {"mesh": "new_arm", "type": "device", "path": "https://oss.example.com/new_arm/macro_device.xacro"},
            mesh_base_dir=Path(self.tmp_dir),
        )
        self.assertTrue(result)
        mock_download.assert_called_once()
        call_args = mock_download.call_args
        # Second positional argument is the local destination path.
        self.assertIn("macro_device.xacro", str(call_args[0][1]))

    @patch("unilabos.app.model_upload._download_file")
    def test_download_resource(self, mock_download):
        """A resource model is downloaded into the resources/ directory."""
        result = download_model_from_oss(
            {
                "mesh": "plate_96/meshes/plate_96.stl",
                "type": "resource",
                "path": "https://oss.example.com/plate_96/modal.xacro",
            },
            mesh_base_dir=Path(self.tmp_dir),
        )
        self.assertTrue(result)
        target_dir = Path(self.tmp_dir) / "resources" / "plate_96"
        self.assertTrue(target_dir.exists())

    @patch("unilabos.app.model_upload._download_file")
    def test_download_with_children_mesh(self, mock_download):
        """A model with a children_mesh entry downloads both files."""
        result = download_model_from_oss(
            {
                "mesh": "tip_rack",
                "type": "device",
                "path": "https://oss.example.com/tip_rack/model.xacro",
                "children_mesh": {
                    "path": "https://oss.example.com/tip_rack/meshes/tip.stl",
                    "format": "stl",
                },
            },
            mesh_base_dir=Path(self.tmp_dir),
        )
        self.assertTrue(result)
        # Expected two calls: the entry file + the children_mesh file.
        self.assertEqual(mock_download.call_count, 2)

    @patch("unilabos.app.model_upload._download_file", side_effect=Exception("network error"))
    def test_download_failure_graceful(self, mock_download):
        """A failing download returns False instead of raising."""
        result = download_model_from_oss(
            {"mesh": "broken", "type": "device", "path": "https://oss.example.com/broken.xacro"},
            mesh_base_dir=Path(self.tmp_dir),
        )
        self.assertFalse(result)
|
||
|
||
|
||
class TestModelExtensions(unittest.TestCase):
    """Checks on the set of file suffixes recognised as model files."""

    # Common 3D / robot-description formats that must be accepted.
    _SUPPORTED = (".stl", ".gltf", ".glb", ".xacro", ".urdf", ".obj", ".dae")
    # Ordinary non-model files that must be rejected.
    _UNSUPPORTED = (".txt", ".json", ".py", ".png", ".jpg")

    def test_standard_extensions(self):
        """Every standard 3D format is in the supported suffix set."""
        for suffix in self._SUPPORTED:
            self.assertIn(suffix, _MODEL_EXTENSIONS, f"{suffix} should be supported")

    def test_non_model_excluded(self):
        """Non-model suffixes are absent from the supported suffix set."""
        for suffix in self._UNSUPPORTED:
            self.assertNotIn(suffix, _MODEL_EXTENSIONS, f"{suffix} should not be supported")
|
||
|
||
|
||
class TestXorTransform(unittest.TestCase):
    """Tests for the core XOR encrypt/decrypt helper ``_xor_transform``."""

    # Built-in default key, shared with the frontend; may be overridden at
    # runtime through the UNILAB_MESH_XOR_KEY environment variable.
    _DEFAULT_KEY = "unilab3d-model-protection-key-v1"

    def _effective_key(self) -> bytes:
        """Return the key ``_xor_transform`` is expected to use by default."""
        return os.environ.get("UNILAB_MESH_XOR_KEY", self._DEFAULT_KEY).encode()

    def test_roundtrip_symmetry(self):
        """Encrypting then decrypting restores the original data (XOR is an involution)."""
        original = b"Hello, this is a test model file content."
        encrypted = _xor_transform(original)
        self.assertNotEqual(encrypted, original)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, original)

    def test_empty_data(self):
        """Empty input stays empty."""
        result = _xor_transform(b"")
        self.assertEqual(result, b"")

    def test_single_byte(self):
        """A single byte round-trips correctly."""
        original = b"\xff"
        encrypted = _xor_transform(original)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, original)

    def test_data_longer_than_key(self):
        """Data longer than the key wraps around the key and still round-trips."""
        original = bytes(range(256)) * 2  # 512 bytes
        encrypted = _xor_transform(original)
        self.assertNotEqual(encrypted, original)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, original)

    def test_data_exactly_key_length(self):
        """Data exactly as long as the default key (32 bytes) round-trips."""
        original = bytes(range(32))
        encrypted = _xor_transform(original)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, original)

    def test_all_zeros_produces_key(self):
        """XOR of all-zero data of key length yields the key itself."""
        key = self._effective_key()
        # Size the input by the actual key length instead of hard-coding 32,
        # so the test also holds when UNILAB_MESH_XOR_KEY sets another length.
        zeros = b"\x00" * len(key)
        result = _xor_transform(zeros)
        self.assertEqual(result, key)

    def test_custom_key(self):
        """An explicit custom key round-trips correctly."""
        custom_key = b"custom-key-12345"
        original = b"test data for custom key"
        encrypted = _xor_transform(original, key=custom_key)
        decrypted = _xor_transform(encrypted, key=custom_key)
        self.assertEqual(decrypted, original)

    def test_different_keys_produce_different_results(self):
        """Different keys produce different ciphertexts."""
        data = b"same data"
        key1 = b"key-one-is-here!"
        key2 = b"key-two-is-here!"
        self.assertNotEqual(_xor_transform(data, key1), _xor_transform(data, key2))

    def test_binary_stl_header(self):
        """Binary content (a simulated binary STL header) round-trips."""
        stl_header = b"\x00" * 80 + b"\x03\x00\x00\x00"
        encrypted = _xor_transform(stl_header)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, stl_header)

    def test_large_data_roundtrip(self):
        """Large (1 MiB) random data round-trips."""
        original = os.urandom(1024 * 1024)
        encrypted = _xor_transform(original)
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, original)

    @unittest.skipIf(
        "UNILAB_MESH_XOR_KEY" in os.environ,
        "default key overridden by UNILAB_MESH_XOR_KEY; frontend-default comparison does not apply",
    )
    def test_consistency_with_frontend_key(self):
        """The Python-side default key matches the frontend's default key."""
        frontend_key = self._DEFAULT_KEY.encode()
        data = b"cross-platform test data"
        encrypted = _xor_transform(data, key=frontend_key)
        # Decrypting with the implicit default key must give the original back.
        decrypted = _xor_transform(encrypted)
        self.assertEqual(decrypted, data)
|
||
|
||
|
||
class TestEncryptExtensions(unittest.TestCase):
    """Configuration checks for which file suffixes are XOR-encrypted."""

    def test_all_mesh_formats_in_encrypt_set(self):
        """The encrypt set is exactly the six binary/mesh formats."""
        mesh_formats = {".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb"}
        self.assertEqual(_MESH_ENCRYPT_EXTENSIONS, mesh_formats)

    def test_xml_formats_not_encrypted(self):
        """XACRO/URDF/YAML description files are uploaded in plain text."""
        plain_text_formats = (".xacro", ".urdf", ".yaml", ".yml")
        for suffix in plain_text_formats:
            self.assertNotIn(suffix, _MESH_ENCRYPT_EXTENSIONS)

    def test_encrypt_is_subset_of_model_extensions(self):
        """Every encrypted suffix is also a recognised model suffix."""
        # An empty difference is equivalent to issubset(), and on failure the
        # assertion message lists the offending suffixes.
        stray = set(_MESH_ENCRYPT_EXTENSIONS) - set(_MODEL_EXTENSIONS)
        self.assertEqual(stray, set())
|
||
|
||
|
||
class TestPutUploadEncryption(unittest.TestCase):
    """Tests for the suffix-dependent encryption performed by ``_put_upload``."""

    _MESH_SUFFIXES = (".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb")

    def _upload_and_capture(self, mock_put, suffix: str, data: bytes):
        """Upload *data* from a temp file named ``model<suffix>``.

        Args:
            mock_put: The patched ``requests.put``.
            suffix: File suffix, which decides whether encryption applies.
            data: Raw file contents written before the upload.

        Returns:
            The ``data=`` keyword passed to ``requests.put`` (the HTTP body).
        """
        from unilabos.app.model_upload import _put_upload

        # TemporaryDirectory guarantees cleanup even if file creation or the
        # upload raises, unlike a NamedTemporaryFile(delete=False) + unlink.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir) / f"model{suffix}"
            tmp_path.write_bytes(data)
            mock_put.reset_mock()
            mock_put.return_value = MagicMock(status_code=200)
            mock_put.return_value.raise_for_status = MagicMock()
            _put_upload(tmp_path, "https://oss.example.com/upload")
        return mock_put.call_args.kwargs.get("data")

    @patch("unilabos.app.model_upload.requests.put")
    def test_stl_file_encrypted_before_upload(self, mock_put):
        """STL files are XOR-encrypted before upload."""
        original_data = b"solid test\nfacet normal 0 0 1\n"
        uploaded_data = self._upload_and_capture(mock_put, ".stl", original_data)

        self.assertIsNotNone(uploaded_data)
        self.assertNotEqual(uploaded_data, original_data)
        # Decrypting the uploaded body must restore the original bytes.
        self.assertEqual(_xor_transform(uploaded_data), original_data)

    @patch("unilabos.app.model_upload.requests.put")
    def test_xacro_file_not_encrypted(self, mock_put):
        """XACRO files are uploaded unmodified."""
        original_data = b'<?xml version="1.0"?><robot></robot>'
        uploaded_data = self._upload_and_capture(mock_put, ".xacro", original_data)
        self.assertEqual(uploaded_data, original_data)

    @patch("unilabos.app.model_upload.requests.put")
    def test_all_mesh_formats_encrypted(self, mock_put):
        """Every mesh format is encrypted before upload and decrypts back."""
        original_data = b"test mesh binary data content"
        for ext in self._MESH_SUFFIXES:
            with self.subTest(ext=ext):
                uploaded_data = self._upload_and_capture(mock_put, ext, original_data)
                self.assertNotEqual(uploaded_data, original_data, f"{ext} file should be encrypted")
                # "Different" alone would also pass for a corrupting transform;
                # require an exact XOR round-trip.
                self.assertEqual(_xor_transform(uploaded_data), original_data)

    @patch("unilabos.app.model_upload.requests.put")
    def test_uppercase_extension_encrypted(self, mock_put):
        """An uppercase suffix (.STL) is also encrypted (case-insensitive match)."""
        original_data = b"uppercase ext test"
        uploaded_data = self._upload_and_capture(mock_put, ".STL", original_data)
        self.assertNotEqual(uploaded_data, original_data)
        self.assertEqual(_xor_transform(uploaded_data), original_data)
|
||
|
||
|
||
class TestDownloadFileDecryption(unittest.TestCase):
    """Tests for the suffix-dependent decryption performed by ``_download_file``."""

    _MESH_SUFFIXES = (".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb")

    @staticmethod
    def _mock_response(content) -> MagicMock:
        """Build a fake ``requests`` response whose ``.content`` is *content*."""
        response = MagicMock()
        response.content = content
        response.raise_for_status = MagicMock()
        return response

    @patch("unilabos.app.model_upload.requests.get")
    def test_mesh_file_decrypted_on_download(self, mock_get):
        """A downloaded mesh file is XOR-decrypted before being written locally."""
        from unilabos.app.model_upload import _download_file

        original_data = b"original stl content here"
        mock_get.return_value = self._mock_response(_xor_transform(original_data))

        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = Path(tmpdir) / "model.stl"
            _download_file("https://oss.example.com/model.stl", local_path)
            self.assertEqual(local_path.read_bytes(), original_data)

    @patch("unilabos.app.model_upload.requests.get")
    def test_xacro_file_not_decrypted(self, mock_get):
        """A downloaded XACRO file is written unmodified."""
        from unilabos.app.model_upload import _download_file

        xml_data = b'<?xml version="1.0"?><robot></robot>'
        mock_get.return_value = self._mock_response(xml_data)

        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = Path(tmpdir) / "macro.xacro"
            _download_file("https://oss.example.com/macro.xacro", local_path)
            self.assertEqual(local_path.read_bytes(), xml_data)

    @patch("unilabos.app.model_upload.requests.get")
    @patch("unilabos.app.model_upload.requests.put")
    def test_upload_download_roundtrip(self, mock_put, mock_get):
        """Full round-trip: _put_upload encrypts, _download_file decrypts.

        The body captured from the upload is served back as the download
        response, so the test exercises both code paths together rather than
        re-encrypting locally with _xor_transform.
        """
        from unilabos.app.model_upload import _download_file, _put_upload

        original_data = b"binary stl mesh \x00\xff\x80 special bytes"

        with tempfile.TemporaryDirectory() as tmpdir:
            source_path = Path(tmpdir) / "source.stl"
            source_path.write_bytes(original_data)
            mock_put.return_value = MagicMock(status_code=200)
            mock_put.return_value.raise_for_status = MagicMock()
            _put_upload(source_path, "https://oss.example.com/upload")
            uploaded_data = mock_put.call_args.kwargs.get("data")

            mock_get.return_value = self._mock_response(uploaded_data)
            local_path = Path(tmpdir) / "mesh.stl"
            _download_file("https://oss.example.com/mesh.stl", local_path)
            self.assertEqual(local_path.read_bytes(), original_data)

    @patch("unilabos.app.model_upload.requests.get")
    def test_all_mesh_formats_decrypted(self, mock_get):
        """Every mesh format is decrypted after download."""
        from unilabos.app.model_upload import _download_file

        original_data = b"mesh content for roundtrip"
        encrypted_data = _xor_transform(original_data)

        for ext in self._MESH_SUFFIXES:
            with self.subTest(ext=ext):
                mock_get.return_value = self._mock_response(encrypted_data)
                with tempfile.TemporaryDirectory() as tmpdir:
                    local_path = Path(tmpdir) / f"model{ext}"
                    _download_file(f"https://oss.example.com/model{ext}", local_path)
                    self.assertEqual(
                        local_path.read_bytes(), original_data, f"{ext} file should be decrypted"
                    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Run this module's tests directly (python <this file>); "__main__" is
    # unittest.main's default module, spelled out for clarity.
    unittest.main(module="__main__")
|