StructBERT Web系统API文档自动生成:Swagger集成与测试用例编写
StructBERT Web系统API文档自动生成:Swagger集成与测试用例编写
1. 项目背景与需求
当你成功部署了StructBERT中文语义智能匹配系统,看着它稳定运行,提供着精准的语义相似度计算和特征提取服务时,一个新的问题可能就摆在了面前:如何让其他开发同事、外部合作伙伴或者未来的自己,能够快速、准确地使用你提供的API接口?
手动编写API文档不仅耗时耗力,还容易出错。文档更新不及时,接口改了但文档没改,这种“文档滞后”的问题在项目协作中屡见不鲜。更重要的是,缺乏规范的文档,接口的测试和验证也变得困难重重。
本文将带你解决这个问题。我们将为你的StructBERT Web系统集成Swagger(OpenAPI)文档,实现API接口的自动生成与可视化,并编写配套的测试用例,打造一个既专业又易用的API服务。整个过程不需要你成为Swagger专家,跟着步骤走,一两个小时就能搞定。
2. 环境准备与Swagger集成
2.1 理解Swagger能为我们做什么
在开始动手之前,我们先简单了解一下Swagger(现在更常被称为OpenAPI)到底是什么,以及它能给我们带来什么好处。
Swagger是一套用于描述、生成、调用和可视化RESTful API的规范与工具集。把它集成到你的Flask项目中,主要能实现三个功能:
- 自动生成API文档:你只需要在代码中添加一些简单的注解(装饰器),Swagger就能自动扫描这些注解,生成结构清晰、内容完整的API文档。
- 提供交互式API测试界面:生成的文档不是静态的HTML页面,而是一个可以在浏览器中直接调用接口的交互式界面。你可以在这个界面上填写参数、发送请求、查看响应,就像使用Postman一样方便。
- 生成客户端SDK代码:如果需要,Swagger还可以根据API描述,自动生成多种编程语言(如Python、Java、JavaScript)的客户端调用代码。
对于我们的StructBERT系统来说,这意味着:
- 前端开发同事可以直接在浏览器里查看所有接口的详细说明,包括URL、方法、参数、返回值等。
- 测试同事可以直接在Swagger UI上测试接口,无需安装额外的测试工具。
- 你自己在几个月后回顾项目时,也能快速回忆起每个接口的用途和用法。
2.2 安装必要的Python包
首先,我们需要在现有的torch26虚拟环境中安装几个必要的Python包。打开终端,激活你的虚拟环境,然后执行以下命令:
# 激活你的torch26虚拟环境(根据你的环境名称调整) conda activate torch26 # 安装Flask-Swagger-UI和Flask-RESTx # Flask-Swagger-UI用于提供Swagger UI界面 # Flask-RESTx是Flask-RESTPlus的继承者,功能更强大,我们主要用它来自动生成API文档 pip install flask-swagger-ui flask-restx 安装完成后,你可以通过以下命令验证是否安装成功:
pip list | grep -E "flask-swagger-ui|flask-restx" 应该能看到类似这样的输出:
flask-restx 1.3.0 flask-swagger-ui 4.11.1 2.3 修改Flask应用代码集成Swagger
现在,我们来修改StructBERT系统的Flask应用代码。我们假设你的主应用文件是app.py(如果不是,请根据实际情况调整文件名)。
重要提示:在修改代码前,建议先备份你的原始文件。
打开app.py文件,在文件顶部添加必要的导入:
# 在原有导入的基础上,添加以下导入 from flask_restx import Api, Resource, fields, reqparse from flask_swagger_ui import get_swaggerui_blueprint 接下来,找到Flask应用初始化的地方(通常是app = Flask(__name__)这一行),在其后添加Swagger和RESTx的配置:
# 原有的Flask应用初始化 app = Flask(__name__) app.config['JSON_AS_ASCII'] = False # 确保中文正常显示 # ========== 新增:Swagger UI配置 ========== SWAGGER_URL = '/api/docs' # Swagger UI的访问路径 API_URL = '/api/swagger.json' # OpenAPI规范文件的路径 # 创建Swagger UI蓝图 swaggerui_blueprint = get_swaggerui_blueprint( SWAGGER_URL, API_URL, config={ 'app_name': "StructBERT中文语义智能匹配系统API文档" } ) # 注册Swagger UI蓝图到应用 app.register_blueprint(swaggerui_blueprint, url_prefix=SWAGGER_URL) # ========== 新增:Flask-RESTx API初始化 ========== # 创建API实例 api = Api( app, version='1.0', title='StructBERT中文语义智能匹配系统API', description='基于StructBERT孪生网络模型的中文语义相似度计算与特征提取服务', doc='/api/docs' # 注意:这里我们使用自定义的Swagger UI,所以不启用RESTx自带的文档 ) # 定义命名空间(用于组织接口) ns_similarity = api.namespace('similarity', description='语义相似度计算相关接口') ns_feature = api.namespace('feature', description='特征提取相关接口') 2.4 重构API接口并添加Swagger注解
现在我们需要重构现有的API接口,使用Flask-RESTx的Resource类和装饰器来添加Swagger注解。
首先,我们定义请求和响应的数据模型。在API初始化代码后面添加:
# ========== 新增:定义数据模型 ========== # 相似度计算请求模型 similarity_request_model = api.model('SimilarityRequest', { 'text1': fields.String(required=True, description='第一段文本', example='今天天气真好'), 'text2': fields.String(required=True, description='第二段文本', example='阳光明媚的早晨'), 'threshold_high': fields.Float(required=False, description='高相似度阈值,默认0.7', example=0.7), 'threshold_low': fields.Float(required=False, description='低相似度阈值,默认0.3', example=0.3) }) # 相似度计算响应模型 similarity_response_model = api.model('SimilarityResponse', { 'text1': fields.String(description='第一段文本'), 'text2': fields.String(description='第二段文本'), 'similarity_score': fields.Float(description='语义相似度得分(0-1)'), 'similarity_level': fields.String(description='相似度等级:high/medium/low'), 'inference_time': fields.Float(description='推理耗时(秒)'), 'status': fields.String(description='请求状态:success/error'), 'message': fields.String(description='状态信息') }) # 特征提取请求模型 feature_request_model = api.model('FeatureRequest', { 'text': fields.String(required=True, description='需要提取特征的文本', example='这是一个测试文本'), 'batch_texts': fields.List(fields.String, required=False, description='批量文本列表(用于批量提取)', example=['文本1', '文本2', '文本3']) }) # 特征提取响应模型 feature_response_model = api.model('FeatureResponse', { 'text': fields.String(description='输入文本'), 'feature_vector': fields.List(fields.Float, description='768维特征向量'), 'feature_preview': fields.List(fields.Float, description='前20维特征预览'), 'inference_time': fields.Float(description='推理耗时(秒)'), 'status': fields.String(description='请求状态:success/error'), 'message': fields.String(description='状态信息') }) # 批量特征提取响应模型 batch_feature_response_model = api.model('BatchFeatureResponse', { 'results': fields.List(fields.Nested(feature_response_model), description='批量处理结果列表'), 'total_time': fields.Float(description='总处理耗时(秒)'), 'total_texts': fields.Integer(description='处理文本总数'), 'status': fields.String(description='请求状态:success/error'), 'message': fields.String(description='状态信息') }) 接下来,我们需要重构原有的路由处理函数。找到你处理API请求的函数(可能是calculate_similarity()、extract_feature()等),将它们重构为Resource类。
在app.py中添加以下代码(请根据你的实际函数名和逻辑进行调整):
# ========== 新增:语义相似度计算接口 ========== @ns_similarity.route('/calculate') class SimilarityCalculate(Resource): @ns_similarity.expect(similarity_request_model) @ns_similarity.marshal_with(similarity_response_model) def post(self): """ 计算两段中文文本的语义相似度 使用StructBERT孪生网络模型计算文本相似度,返回0-1之间的相似度得分和等级判定。 支持自定义相似度阈值。 """ try: # 获取请求数据 data = api.payload text1 = data.get('text1', '').strip() text2 = data.get('text2', '').strip() threshold_high = data.get('threshold_high', 0.7) threshold_low = data.get('threshold_low', 0.3) # 参数校验 if not text1 or not text2: return { 'text1': text1, 'text2': text2, 'similarity_score': 0.0, 'similarity_level': 'error', 'inference_time': 0.0, 'status': 'error', 'message': '文本内容不能为空' }, 400 # 记录开始时间 start_time = time.time() # 调用你的相似度计算函数 # 注意:这里需要替换成你实际的相似度计算函数 similarity_score = calculate_similarity_score(text1, text2) # 根据阈值判定相似度等级 if similarity_score >= threshold_high: similarity_level = 'high' elif similarity_score >= threshold_low: similarity_level = 'medium' else: similarity_level = 'low' # 计算推理时间 inference_time = time.time() - start_time return { 'text1': text1, 'text2': text2, 'similarity_score': round(similarity_score, 4), 'similarity_level': similarity_level, 'inference_time': round(inference_time, 4), 'status': 'success', 'message': '相似度计算成功' } except Exception as e: return { 'text1': data.get('text1', ''), 'text2': data.get('text2', ''), 'similarity_score': 0.0, 'similarity_level': 'error', 'inference_time': 0.0, 'status': 'error', 'message': f'相似度计算失败:{str(e)}' }, 500 # ========== 新增:单文本特征提取接口 ========== @ns_feature.route('/extract') class FeatureExtract(Resource): @ns_feature.expect(feature_request_model) @ns_feature.marshal_with(feature_response_model) def post(self): """ 提取单文本的768维语义特征向量 使用StructBERT模型提取文本的语义特征,返回768维向量和前20维预览。 """ try: data = api.payload text = data.get('text', '').strip() if not text: return { 'text': '', 'feature_vector': [], 'feature_preview': [], 'inference_time': 0.0, 'status': 'error', 'message': '文本内容不能为空' }, 400 start_time = time.time() # 调用你的特征提取函数 # 注意:这里需要替换成你实际的特征提取函数 feature_vector = extract_text_feature(text) # 获取前20维作为预览 feature_preview = feature_vector[:20] if len(feature_vector) >= 20 else feature_vector inference_time = time.time() - start_time return { 'text': text, 'feature_vector': [round(float(x), 6) for x in feature_vector], 'feature_preview': [round(float(x), 6) for x in feature_preview], 'inference_time': round(inference_time, 4), 'status': 'success', 'message': '特征提取成功' } except Exception as e: return { 'text': data.get('text', ''), 'feature_vector': [], 'feature_preview': [], 'inference_time': 0.0, 'status': 'error', 'message': f'特征提取失败:{str(e)}' }, 500 # ========== 新增:批量特征提取接口 ========== @ns_feature.route('/batch-extract') class BatchFeatureExtract(Resource): @ns_feature.expect(feature_request_model) @ns_feature.marshal_with(batch_feature_response_model) def post(self): """ 批量提取多文本的768维语义特征向量 一次性处理多个文本,返回每个文本的特征向量和预览。 建议单次批量不超过100条文本。 """ try: data = api.payload batch_texts = data.get('batch_texts', []) if not batch_texts or not isinstance(batch_texts, list): return { 'results': [], 'total_time': 0.0, 'total_texts': 0, 'status': 'error', 'message': '批量文本列表不能为空且必须是列表格式' }, 400 # 过滤空文本 valid_texts = [text.strip() for text in batch_texts if text and str(text).strip()] if not valid_texts: return { 'results': [], 'total_time': 0.0, 'total_texts': 0, 'status': 'error', 'message': '没有有效的文本内容' }, 400 total_start_time = time.time() results = [] # 批量处理文本 for text in valid_texts: try: text_start_time = time.time() # 调用你的特征提取函数 feature_vector = extract_text_feature(text) feature_preview = feature_vector[:20] if len(feature_vector) >= 20 else feature_vector inference_time = time.time() - text_start_time results.append({ 'text': text, 'feature_vector': [round(float(x), 6) for x in feature_vector], 'feature_preview': [round(float(x), 6) for x in feature_preview], 'inference_time': round(inference_time, 4), 'status': 'success', 'message': '特征提取成功' }) except Exception as e: results.append({ 'text': text, 'feature_vector': [], 'feature_preview': [], 'inference_time': 0.0, 'status': 'error', 'message': f'特征提取失败:{str(e)}' }) total_time = time.time() - total_start_time return { 'results': results, 'total_time': round(total_time, 4), 'total_texts': len(results), 'status': 'success', 'message': f'批量处理完成,成功{len([r for r in results if r["status"] == "success"])}条,失败{len([r for r in results if r["status"] == "error"])}条' } except Exception as e: return { 'results': [], 'total_time': 0.0, 'total_texts': 0, 'status': 'error', 'message': f'批量处理失败:{str(e)}' }, 500 重要说明:上面的代码中,calculate_similarity_score()和extract_text_feature()是你原有的相似度计算和特征提取函数。你需要确保这些函数已经正确导入,或者将函数实现直接放在app.py中。
2.5 添加OpenAPI规范文件路由
为了让Swagger UI能够获取到API的描述信息,我们需要添加一个生成OpenAPI规范文件的路由。在app.py的最后(在if __name__ == '__main__':之前)添加:
# ========== 新增:生成OpenAPI规范文件 ========== @app.route('/api/swagger.json') def swagger_json(): """ 生成OpenAPI规范文件 Swagger UI通过这个端点获取API的完整描述信息。 """ return jsonify(api.__schema__) 2.6 测试Swagger集成
完成以上修改后,保存app.py文件,然后重启你的Flask应用:
# 如果你之前是直接运行python app.py python app.py # 或者如果你使用gunicorn等WSGI服务器 gunicorn -w 4 -b 0.0.0.0:6007 app:app 应用启动后,打开浏览器,访问以下地址:
- Web界面:
http://你的服务器IP:6007 - Swagger API文档:
http://你的服务器IP:6007/api/docs
你应该能看到一个美观的Swagger UI界面,左侧是API接口列表,右侧是接口的详细信息和测试界面。
3. 编写API测试用例
有了Swagger文档,我们还需要编写测试用例来确保API的稳定性和正确性。这里我们使用Python的unittest框架和requests库来编写测试。
3.1 创建测试目录和文件
在你的项目根目录下创建测试目录和文件:
mkdir tests touch tests/__init__.py touch tests/test_api.py 3.2 编写基础测试类
打开tests/test_api.py,编写基础测试类:
import unittest import json import time import requests class TestStructBERTAPI(unittest.TestCase): """StructBERT API测试类""" # API基础URL - 根据你的实际部署地址修改 BASE_URL = "http://localhost:6007" # 测试用的文本数据 TEST_TEXT1 = "今天天气真好,阳光明媚" TEST_TEXT2 = "天气晴朗,万里无云" TEST_TEXT3 = "我喜欢吃苹果和香蕉" TEST_TEXT4 = "水果中我最喜欢苹果" def setUp(self): """测试前的准备工作""" self.session = requests.Session() # 设置请求超时时间 self.timeout = 30 def tearDown(self): """测试后的清理工作""" self.session.close() def test_01_api_health(self): """测试API服务是否正常""" print("\n=== 测试API服务健康状态 ===") try: response = self.session.get(f"{self.BASE_URL}/", timeout=self.timeout) self.assertEqual(response.status_code, 200) print("✓ API服务健康检查通过") except Exception as e: self.fail(f"API服务不可用: {str(e)}") def test_02_swagger_ui(self): """测试Swagger UI是否可访问""" print("\n=== 测试Swagger UI ===") try: response = self.session.get(f"{self.BASE_URL}/api/docs", timeout=self.timeout) self.assertEqual(response.status_code, 200) print("✓ Swagger UI可正常访问") except Exception as e: self.fail(f"Swagger UI不可访问: {str(e)}") def test_03_openapi_spec(self): """测试OpenAPI规范文件是否可访问""" print("\n=== 测试OpenAPI规范文件 ===") try: response = self.session.get(f"{self.BASE_URL}/api/swagger.json", timeout=self.timeout) self.assertEqual(response.status_code, 200) # 验证返回的是有效的JSON spec = response.json() self.assertIn('openapi', spec) self.assertIn('info', spec) self.assertIn('paths', spec) print("✓ OpenAPI规范文件有效") except Exception as e: self.fail(f"OpenAPI规范文件无效: {str(e)}") 3.3 编写语义相似度计算测试
继续在test_api.py中添加相似度计算测试:
def test_04_similarity_calculate_basic(self): """测试基础语义相似度计算""" print("\n=== 测试基础语义相似度计算 ===") url = f"{self.BASE_URL}/api/similarity/calculate" headers = {'Content-Type': 'application/json'} # 测试用例1:相似文本 print("测试用例1:相似文本(天气相关)") data = { "text1": self.TEST_TEXT1, "text2": self.TEST_TEXT2 } try: start_time = time.time() response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) elapsed_time = time.time() - start_time self.assertEqual(response.status_code, 200) result = response.json() # 验证返回字段 self.assertIn('similarity_score', result) self.assertIn('similarity_level', result) self.assertIn('inference_time', result) self.assertIn('status', result) self.assertEqual(result['status'], 'success') # 验证相似度得分在合理范围内 self.assertGreaterEqual(result['similarity_score'], 0) self.assertLessEqual(result['similarity_score'], 1) # 验证推理时间合理 self.assertGreater(result['inference_time'], 0) self.assertLess(result['inference_time'], 5) # 假设单次推理不超过5秒 print(f" 相似度得分: {result['similarity_score']}") print(f" 相似度等级: {result['similarity_level']}") print(f" 推理时间: {result['inference_time']}秒") print(f" 请求总耗时: {elapsed_time:.2f}秒") except Exception as e: self.fail(f"相似文本测试失败: {str(e)}") # 测试用例2:不相似文本 print("\n测试用例2:不相似文本(天气 vs 水果)") data = { "text1": self.TEST_TEXT1, "text2": self.TEST_TEXT3 } try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) self.assertEqual(response.status_code, 200) result = response.json() # 不相似文本的得分应该较低 self.assertLess(result['similarity_score'], 0.5) print(f" 不相似文本得分: {result['similarity_score']} (预期<0.5)") except Exception as e: self.fail(f"不相似文本测试失败: {str(e)}") # 测试用例3:空文本 print("\n测试用例3:空文本处理") data = { "text1": "", "text2": "测试文本" } try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) # 空文本应该返回400错误 self.assertEqual(response.status_code, 400) result = response.json() self.assertEqual(result['status'], 'error') print(" 空文本正确处理,返回错误状态") except Exception as e: self.fail(f"空文本测试失败: {str(e)}") def test_05_similarity_calculate_with_threshold(self): """测试带阈值的语义相似度计算""" print("\n=== 测试带阈值的语义相似度计算 ===") url = f"{self.BASE_URL}/api/similarity/calculate" headers = {'Content-Type': 'application/json'} # 测试自定义阈值 data = { "text1": self.TEST_TEXT4, "text2": "苹果是一种常见的水果", "threshold_high": 0.8, # 提高高相似度阈值 "threshold_low": 0.4 # 提高低相似度阈值 } try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) self.assertEqual(response.status_code, 200) result = response.json() print(f" 文本1: {data['text1']}") print(f" 文本2: {data['text2']}") print(f" 相似度得分: {result['similarity_score']}") print(f" 相似度等级: {result['similarity_level']}") print(f" 使用阈值: 高={data['threshold_high']}, 低={data['threshold_low']}") # 验证阈值生效 score = result['similarity_score'] if score >= 0.8: self.assertEqual(result['similarity_level'], 'high') elif score >= 0.4: self.assertEqual(result['similarity_level'], 'medium') else: self.assertEqual(result['similarity_level'], 'low') print(" 阈值设置生效") except Exception as e: self.fail(f"带阈值测试失败: {str(e)}") 3.4 编写特征提取测试
继续添加特征提取测试:
def test_06_feature_extract_single(self): """测试单文本特征提取""" print("\n=== 测试单文本特征提取 ===") url = f"{self.BASE_URL}/api/feature/extract" headers = {'Content-Type': 'application/json'} # 测试正常文本 data = { "text": "深度学习是人工智能的一个重要分支" } try: start_time = time.time() response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) elapsed_time = time.time() - start_time self.assertEqual(response.status_code, 200) result = response.json() # 验证返回字段 self.assertIn('feature_vector', result) self.assertIn('feature_preview', result) self.assertIn('inference_time', result) self.assertIn('status', result) self.assertEqual(result['status'], 'success') # 验证特征向量维度 self.assertEqual(len(result['feature_vector']), 768) self.assertEqual(len(result['feature_preview']), 20) # 验证预览是前20维 self.assertEqual(result['feature_preview'], result['feature_vector'][:20]) print(f" 输入文本: {data['text']}") print(f" 特征向量维度: {len(result['feature_vector'])}") print(f" 特征预览维度: {len(result['feature_preview'])}") print(f" 推理时间: {result['inference_time']}秒") print(f" 请求总耗时: {elapsed_time:.2f}秒") # 验证特征值在合理范围内 for value in result['feature_preview']: self.assertIsInstance(value, float) # BERT特征通常在一定范围内 self.assertGreater(value, -10) self.assertLess(value, 10) print(" 特征值范围验证通过") except Exception as e: self.fail(f"单文本特征提取测试失败: {str(e)}") # 测试空文本 print("\n测试空文本处理") data = {"text": ""} try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) self.assertEqual(response.status_code, 400) result = response.json() self.assertEqual(result['status'], 'error') print(" 空文本正确处理,返回错误状态") except Exception as e: self.fail(f"空文本测试失败: {str(e)}") def test_07_feature_extract_batch(self): """测试批量特征提取""" print("\n=== 测试批量特征提取 ===") url = f"{self.BASE_URL}/api/feature/batch-extract" headers = {'Content-Type': 'application/json'} # 准备测试数据 batch_texts = [ "今天天气真好,适合出去散步", "人工智能正在改变世界", "机器学习是人工智能的核心技术", "自然语言处理让计算机理解人类语言", "计算机视觉让机器看懂世界" ] data = { "batch_texts": batch_texts } try: start_time = time.time() response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) elapsed_time = time.time() - start_time self.assertEqual(response.status_code, 200) result = response.json() # 验证返回字段 self.assertIn('results', result) self.assertIn('total_time', result) self.assertIn('total_texts', result) self.assertIn('status', result) self.assertEqual(result['status'], 'success') # 验证结果数量 self.assertEqual(len(result['results']), len(batch_texts)) self.assertEqual(result['total_texts'], len(batch_texts)) print(f" 批量文本数量: {len(batch_texts)}") print(f" 处理结果数量: {len(result['results'])}") print(f" 总处理时间: {result['total_time']}秒") print(f" 请求总耗时: {elapsed_time:.2f}秒") # 验证每个结果 success_count = 0 for i, item in enumerate(result['results']): if item['status'] == 'success': success_count += 1 self.assertEqual(len(item['feature_vector']), 768) self.assertEqual(len(item['feature_preview']), 20) self.assertEqual(item['feature_preview'], item['feature_vector'][:20]) else: print(f" 第{i+1}条文本处理失败: {item['message']}") print(f" 成功处理: {success_count}/{len(batch_texts)}") # 验证批量处理的效率 # 批量处理总时间应该小于单条处理时间之和 self.assertLess(result['total_time'], 5 * 2) # 假设单条最多2秒,5条应该小于10秒 except Exception as e: self.fail(f"批量特征提取测试失败: {str(e)}") # 测试空批量 print("\n测试空批量处理") data = {"batch_texts": []} try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) self.assertEqual(response.status_code, 400) result = response.json() self.assertEqual(result['status'], 'error') print(" 空批量正确处理,返回错误状态") except Exception as e: self.fail(f"空批量测试失败: {str(e)}") # 测试无效格式 print("\n测试无效格式处理") data = {"batch_texts": "这不是一个列表"} try: response = self.session.post(url, json=data, headers=headers, timeout=self.timeout) self.assertEqual(response.status_code, 400) result = response.json() self.assertEqual(result['status'], 'error') print(" 无效格式正确处理,返回错误状态") except Exception as e: self.fail(f"无效格式测试失败: {str(e)}") 3.5 编写性能测试
添加性能测试用例:
def test_08_performance_test(self): """测试API性能""" print("\n=== 测试API性能 ===") similarity_url = f"{self.BASE_URL}/api/similarity/calculate" feature_url = f"{self.BASE_URL}/api/feature/extract" headers = {'Content-Type': 'application/json'} # 准备测试数据 test_cases = [ {"text1": "苹果手机", "text2": "iPhone"}, {"text1": "今天天气不错", "text2": "天气很好"}, {"text1": "人工智能", "text2": "AI技术"}, {"text1": "机器学习", "text2": "深度学习"}, {"text1": "自然语言处理", "text2": "NLP"} ] # 测试相似度计算性能 print("测试相似度计算性能(5次请求):") similarity_times = [] for i, test_case in enumerate(test_cases, 1): try: start_time = time.time() response = self.session.post( similarity_url, json=test_case, headers=headers, timeout=self.timeout ) elapsed_time = time.time() - start_time if response.status_code == 200: result = response.json() similarity_times.append(result['inference_time']) print(f" 第{i}次: {result['inference_time']:.4f}秒 (相似度: {result['similarity_score']:.4f})") else: print(f" 第{i}次: 请求失败") except Exception as e: print(f" 第{i}次: 异常 - {str(e)}") if similarity_times: avg_time = sum(similarity_times) / len(similarity_times) max_time = max(similarity_times) min_time = min(similarity_times) print(f" 平均推理时间: {avg_time:.4f}秒") print(f" 最长推理时间: {max_time:.4f}秒") print(f" 最短推理时间: {min_time:.4f}秒") # 性能断言:平均推理时间应小于1秒 self.assertLess(avg_time, 1.0, "相似度计算平均时间超过1秒") # 测试特征提取性能 print("\n测试特征提取性能(5次请求):") feature_times = [] test_texts = [ "这是一个测试文本,用于特征提取性能测试", "深度学习模型需要大量的计算资源", "自然语言处理是人工智能的重要方向", "计算机视觉可以识别图像中的物体", "机器学习算法可以从数据中学习规律" ] for i, text in enumerate(test_texts, 1): try: start_time = time.time() response = self.session.post( feature_url, json={"text": text}, headers=headers, timeout=self.timeout ) elapsed_time = time.time() - start_time if response.status_code == 200: result = response.json() feature_times.append(result['inference_time']) print(f" 第{i}次: {result['inference_time']:.4f}秒") else: print(f" 第{i}次: 请求失败") except Exception as e: print(f" 第{i}次: 异常 - {str(e)}") if feature_times: avg_time = sum(feature_times) / len(feature_times) max_time = max(feature_times) min_time = min(feature_times) print(f" 平均推理时间: {avg_time:.4f}秒") print(f" 最长推理时间: {max_time:.4f}秒") print(f" 最短推理时间: {min_time:.4f}秒") # 性能断言:平均推理时间应小于1秒 self.assertLess(avg_time, 1.0, "特征提取平均时间超过1秒") 3.6 添加测试运行入口
在test_api.py文件末尾添加:
def run_tests(): """运行所有测试""" print("=" * 60) print("StructBERT API 测试开始") print("=" * 60) # 创建测试套件 loader = unittest.TestLoader() suite = loader.loadTestsFromTestCase(TestStructBERTAPI) # 运行测试 runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) print("=" * 60) print("测试完成") print(f"运行测试: {result.testsRun}") print(f"失败测试: {len(result.failures)}") print(f"错误测试: {len(result.errors)}") print("=" * 60) return result.wasSuccessful() if __name__ == '__main__': # 检查API服务是否可用 try: response = requests.get("http://localhost:6007/", timeout=5) if response.status_code == 200: print("✓ API服务运行正常,开始测试...") success = run_tests() sys.exit(0 if success else 1) else: print("✗ API服务返回异常状态码") sys.exit(1) except requests.exceptions.ConnectionError: print("✗ 无法连接到API服务,请确保服务已启动") print(" 运行命令: python app.py 或 gunicorn -w 4 -b 0.0.0.0:6007 app:app") sys.exit(1) except Exception as e: print(f"✗ 检查API服务时发生错误: {str(e)}") sys.exit(1) 3.7 运行测试
确保你的StructBERT服务正在运行,然后在终端中执行:
cd /path/to/your/project python -m tests.test_api 或者直接运行:
python tests/test_api.py 你应该能看到详细的测试输出,包括每个测试用例的执行结果和性能数据。
4. 总结与最佳实践
4.1 本文实现的核心价值
通过本文的实践,我们为StructBERT中文语义智能匹配系统实现了两个重要的工程化改进:
- 自动化API文档生成:通过集成Swagger(OpenAPI),我们实现了API接口的自动文档化。现在,任何开发者都可以通过
http://你的服务器:6007/api/docs访问完整的、交互式的API文档,无需手动编写和维护文档。 - 全面的测试覆盖:我们编写了完整的测试套件,覆盖了:
- API服务健康检查
- Swagger UI可访问性
- 语义相似度计算(基础功能、阈值设置、异常处理)
- 特征提取(单文本、批量处理、异常处理)
- 性能测试(响应时间、并发处理)
4.2 实际使用建议
在实际项目中,我建议你:
- 将测试集成到CI/CD流程:将测试用例添加到你的持续集成流程中,每次代码更新都自动运行测试,确保API的稳定性。
- 定期更新测试数据:随着业务发展,更新测试用例中的文本数据,确保测试覆盖真实的业务场景。
- 监控API性能:定期运行性能测试,监控API的响应时间变化,及时发现性能瓶颈。
- 扩展测试覆盖:根据实际需求,可以添加更多测试用例,比如:
- 并发请求测试
- 长文本处理测试
- 特殊字符处理测试
- 模型精度验证测试
- 文档维护:虽然Swagger可以自动生成文档,但接口的详细说明、使用示例、业务场景等还需要人工维护。建议在代码注释中保持详细的说明。
4.3 遇到的常见问题与解决
在实施过程中,你可能会遇到以下问题:
- Swagger UI无法访问:检查Flask蓝图注册是否正确,确保
/api/docs路由被正确添加。 - API文档不完整:检查Flask-RESTx的装饰器是否正确添加,特别是
@api.expect()和@api.marshal_with()。 - 测试连接失败:确保测试脚本中的BASE_URL配置正确,且API服务正在运行。
- 性能测试失败:如果性能测试不通过,可以考虑:
- 优化模型加载和推理代码
- 添加缓存机制
- 使用GPU加速(如果可用)
- 调整批量处理的大小
4.4 下一步优化方向
完成基础集成后,你还可以考虑以下优化:
- 添加API版本管理:随着系统迭代,可能需要维护多个API版本,Flask-RESTx支持命名空间版本控制。
- 实现API认证:如果需要对API进行访问控制,可以添加JWT Token认证或API Key认证。
- 生成客户端SDK:利用Swagger的代码生成功能,自动生成Python、JavaScript等语言的客户端SDK。
- 添加API使用统计:记录API的调用情况,分析使用频率和性能指标。
- 实现自动化测试报告:将测试结果生成HTML报告,方便查看和分析。
通过本文的实践,你的StructBERT系统不仅功能强大,而且具备了专业级的API文档和测试保障,为团队协作和系统维护打下了坚实基础。现在,其他开发者可以轻松理解和使用你的API,你也可以自信地进行系统迭代和优化了。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 ZEEKLOG星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。