commit ecb3e1d9b233a451c34891a39207197c27c065ac Author: LUOJIE\coolp Date: Wed Mar 4 12:17:52 2026 +0800 first commit diff --git a/.env b/.env new file mode 100644 index 0000000..17a5fcd --- /dev/null +++ b/.env @@ -0,0 +1,33 @@ +# 数据库接口服务配置文件 +# 复制此文件为 .env 并修改相应的配置值 + +# 是否启用示例数据初始化 (true/false) +ENABLE_SAMPLE_DATA=true + +# MySQL数据库配置 +MYSQL_HOST=192.168.13.27 +MYSQL_PORT=18903 +MYSQL_USERNAME=luojie +MYSQL_PASSWORD=123456 +MYSQL_DATABASE=testdb + +# Oracle数据库配置 +ORACLE_HOST=192.168.13.27 +ORACLE_PORT=1521 +ORACLE_USERNAME=bizuser +ORACLE_PASSWORD=MySecurePass123 +ORACLE_SERVICE_NAME=ORCLPDB1 + +# SQL Server数据库配置 +SQLSERVER_HOST=192.168.11.200 +SQLSERVER_PORT=1433 +SQLSERVER_USERNAME=sa +SQLSERVER_PASSWORD=sqlserver@7740 +SQLSERVER_DATABASE=test + +# PostgreSQL数据库配置 +POSTGRESQL_HOST=localhost +POSTGRESQL_PORT=5432 +POSTGRESQL_USERNAME=postgres +POSTGRESQL_PASSWORD=password +POSTGRESQL_DATABASE=postgres \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..3be59a2 --- /dev/null +++ b/.env.example @@ -0,0 +1,33 @@ +# 数据库接口服务配置文件 +# 复制此文件为 .env 并修改相应的配置值 + +# 是否启用示例数据初始化 (true/false) +ENABLE_SAMPLE_DATA=true + +# MySQL数据库配置 +MYSQL_HOST=localhost +MYSQL_PORT=3306 +MYSQL_USERNAME=root +MYSQL_PASSWORD=password +MYSQL_DATABASE=test_db + +# Oracle数据库配置 +ORACLE_HOST=192.168.13.27 +ORACLE_PORT=1521 +ORACLE_USERNAME=bizuser +ORACLE_PASSWORD=MySecurePass123 +ORACLE_SERVICE_NAME=ORCLPDB1 + +# SQL Server数据库配置 +SQLSERVER_HOST=localhost +SQLSERVER_PORT=1433 +SQLSERVER_USERNAME=sa +SQLSERVER_PASSWORD=password +SQLSERVER_DATABASE=master + +# PostgreSQL数据库配置 +POSTGRESQL_HOST=localhost +POSTGRESQL_PORT=5432 +POSTGRESQL_USERNAME=postgres +POSTGRESQL_PASSWORD=password +POSTGRESQL_DATABASE=postgres \ No newline at end of file diff --git a/.trae/rules/project_rules.md b/.trae/rules/project_rules.md new file mode 100644 index 0000000..001c031 --- /dev/null +++ b/.trae/rules/project_rules.md @@ -0,0 +1 @@ +1. 通过fastapi启动api服务,使用sqlalchemy来创建连接引擎,对不同的数据库使用不同的驱动,mysql使用PyMySQL,Oracle使用oracledb,sqlserver使用pymssql,postgresql使用psycopg2 diff --git a/API_USAGE.md b/API_USAGE.md new file mode 100644 index 0000000..efb05b2 --- /dev/null +++ b/API_USAGE.md @@ -0,0 +1,266 @@ +# 数据库接口服务 API 使用说明 + +## 项目启动 + +1. 安装依赖: +```bash +pip install -r requirements.txt +``` + +2. 启动服务: +```bash +python main.py +``` + +3. 访问API文档: +- Swagger UI: http://localhost:8000/docs +- ReDoc: http://localhost:8000/redoc + +## 主要功能 + +### 1. 数据库连接管理 + +#### 创建连接 +```http +POST /api/v1/connections +Content-Type: application/json + +{ + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "username": "root", + "password": "password", + "database": "test_db" +} +``` + +#### 获取所有连接 +```http +GET /api/v1/connections +``` + +#### 关闭连接(POST,JSON传参) +```http +POST /api/v1/connections/close +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db" +} +``` + +### 2. 数据库信息查询 + +#### 获取数据库信息(使用query参数) +```http +GET /api/v1/databases/info?connection_id= +``` + +#### 获取表信息(使用query参数) +```http +GET /api/v1/databases/tables/info?connection_id=&table_name= +``` + +### 3. SQL执行 + +#### 执行查询SQL +```http +POST /api/v1/query +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "sql": "SELECT * FROM users WHERE age > :age", + "params": {"age": 18} +} +``` + +#### 执行非查询SQL +```http +POST /api/v1/execute +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "sql": "UPDATE users SET name = :name WHERE id = :id", + "params": {"name": "新名称", "id": 1} +} +``` + +### 4. 表数据CRUD操作 + +#### 查询表数据 +```http +POST /api/v1/tables/data/select +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "page": 1, + "page_size": 10, + "where_clause": "age > 18", + "order_by": "id DESC" +} +``` + +#### 插入数据 +```http +POST /api/v1/tables/data/insert +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "data": { + "name": "张三", + "age": 25, + "email": "zhangsan@example.com" + } +} +``` + +#### 更新数据(改为POST) +```http +POST /api/v1/tables/data/update +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "data": { + "name": "李四", + "age": 30 + }, + "where_clause": "id = 1" +} +``` + +#### 删除数据(改为POST) +```http +POST /api/v1/tables/data/delete +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "where_clause": "id = 1" +} +``` + +### 5. 表结构管理 + +#### 创建表 +```http +POST /api/v1/tables/create +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "new_table", + "columns": [ + { + "name": "id", + "type": "INT", + "primary_key": true, + "not_null": true + }, + { + "name": "name", + "type": "VARCHAR(100)", + "not_null": true, + "comment": "用户名称" + } + ] +} +``` + +#### 删除表(POST,JSON传参) +```http +POST /api/v1/tables/delete +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users" +} +``` + +#### 修改表结构(改为POST) +```http +POST /api/v1/tables/alter +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "operation": "ADD", + "column_definition": { + "name": "phone", + "type": "VARCHAR(20)", + "not_null": false + } +} +``` + +### 6. 备注管理 + +#### 修改表或字段备注(改为POST) +```http +POST /api/v1/tables/comment +Content-Type: application/json + +{ + "connection_id": "mysql_localhost_3306_test_db", + "table_name": "users", + "column_name": "name", + "comment": "用户姓名字段" +} +``` + +### 7. 其他GET接口(统一使用query参数) + +#### 获取数据库中的所有表 +```http +GET /api/v1/tables?connection_id= +``` + +#### 获取表的所有字段信息 +```http +GET /api/v1/tables/columns?connection_id=&table_name=
+``` +``` + +## 支持的数据库类型 + +- **MySQL**: 使用 PyMySQL 驱动 +- **Oracle**: 使用 oracledb 驱动 +- **SQL Server**: 使用 pymssql 驱动 +- **PostgreSQL**: 使用 psycopg2 驱动 + +## 响应格式 + +所有API接口都遵循统一的响应格式: + +```json +{ + "success": true, + "message": "操作成功", + "data": {}, + "error": null +} +``` + +- `success`: 布尔值,表示操作是否成功 +- `message`: 字符串,操作结果描述 +- `data`: 对象,返回的数据(可选) +- `error`: 字符串,错误信息(仅在失败时存在) + +## 注意事项 + +1. 连接ID格式:`{db_type}_{host}_{port}_{database}` +2. SQL参数使用命名参数格式,如 `:param_name` +3. 所有接口都支持CORS跨域访问 +4. 服务启动后会自动管理数据库连接池 +5. 应用关闭时会自动清理所有数据库连接 diff --git a/ORACLE_CONNECTION_GUIDE.md b/ORACLE_CONNECTION_GUIDE.md new file mode 100644 index 0000000..14939be --- /dev/null +++ b/ORACLE_CONNECTION_GUIDE.md @@ -0,0 +1,186 @@ +# Oracle数据库连接指南 + +## Oracle连接角色说明 + +### 什么是Oracle角色? + +Oracle数据库中的角色(Role)是一组权限的集合,用于简化用户权限管理。角色不是连接参数,而是数据库内部的权限管理机制。 + +### 常见的Oracle角色 + +1. **CONNECT** - 基本连接权限 + - 允许用户连接到数据库 + - 创建表、视图、序列等基本对象 + +2. **RESOURCE** - 资源使用权限 + - 允许用户创建存储过程、触发器等 + - 使用表空间资源 + +3. **DBA** - 数据库管理员权限 + - 完全的数据库管理权限 + - 可以管理所有数据库对象 + +4. **SELECT_CATALOG_ROLE** - 数据字典查询权限 + - 允许查询数据字典视图 + +### 默认角色 + +对于普通用户(如 `bizuser`),通常会被授予以下默认角色: +- **CONNECT** - 基本连接权限 +- **RESOURCE** - 资源使用权限 + +## Oracle连接参数说明 + +### 基本连接参数 + +- **主机地址 (host)**: Oracle数据库服务器的IP地址或主机名 +- **端口 (port)**: Oracle监听器端口,默认为1521 +- **服务名称 (service_name)**: Oracle数据库的服务名称,如ORCLPDB1 +- **用户名 (username)**: 数据库用户名 +- **密码 (password)**: 数据库密码 + +### 高级连接参数 + +- **mode**: 连接模式 + - `SYSDBA`: 系统管理员模式 + - `SYSOPER`: 系统操作员模式 + - `NORMAL`: 普通用户模式(默认) + +- **threaded**: 线程模式 + - `true`: 启用线程模式(推荐) + - `false`: 禁用线程模式 + +## 连接字符串格式 + +### 标准格式 +``` +oracle+oracledb://username:password@host:port/service_name +``` + +### 带参数格式 +``` +oracle+oracledb://username:password@host:port/service_name?encoding=UTF-8&nencoding=UTF-8&threaded=true +``` + +## 用户提供的连接信息 + +根据您提供的Navicat连接信息: + +``` +主机地址: 192.168.13.27 +端口: 1521 +服务名称: ORCLPDB1 +用户名: bizuser +密码: MySecurePass123 +角色: Default (CONNECT + RESOURCE) +``` + +## 连接测试 + +### 使用测试脚本 + +1. 启动API服务: + ```bash + python main.py + ``` + +2. 运行Oracle连接测试: + ```bash + python test_oracle_connection.py + ``` + +### 使用API接口 + +```json +{ + "db_type": "oracle", + "host": "192.168.13.27", + "port": 1521, + "username": "bizuser", + "password": "MySecurePass123", + "database": "ORCLPDB1", + "threaded": true +} +``` + +## 常见连接问题及解决方案 + +### 1. TNS: 无法解析指定的连接标识符 + +**原因**: 服务名称不正确或Oracle监听器未启动 + +**解决方案**: +- 检查服务名称是否正确 +- 确认Oracle监听器正在运行 +- 使用 `lsnrctl status` 检查监听器状态 + +### 2. ORA-12541: TNS: 无监听程序 + +**原因**: Oracle监听器未启动或端口被占用 + +**解决方案**: +- 启动Oracle监听器: `lsnrctl start` +- 检查端口1521是否被占用 +- 确认防火墙设置 + +### 3. ORA-01017: 用户名/口令无效 + +**原因**: 用户名或密码错误 + +**解决方案**: +- 验证用户名和密码 +- 检查用户账户是否被锁定 +- 确认用户是否存在 + +### 4. ORA-12514: TNS: 监听程序当前无法识别连接描述符中请求的服务 + +**原因**: 服务名称不存在或未注册到监听器 + +**解决方案**: +- 检查服务名称是否正确 +- 使用 `lsnrctl services` 查看可用服务 +- 确认数据库实例正在运行 + +## 连接池配置 + +本项目为Oracle连接配置了专门的连接池参数: + +```python +engine_kwargs = { + "pool_size": 10, # 连接池大小 + "max_overflow": 20, # 最大溢出连接数 + "pool_timeout": 30, # 连接超时时间(秒) + "pool_recycle": 3600 # 连接回收时间(秒) +} +``` + +## 环境变量配置 + +可以通过环境变量配置Oracle连接参数: + +```bash +# .env文件 +ORACLE_HOST=192.168.13.27 +ORACLE_PORT=1521 +ORACLE_USERNAME=bizuser +ORACLE_PASSWORD=MySecurePass123 +ORACLE_SERVICE_NAME=ORCLPDB1 +``` + +## 注意事项 + +1. **角色不是连接参数**: Oracle角色是数据库内部的权限管理机制,不需要在连接字符串中指定 +2. **服务名称 vs SID**: 现代Oracle推荐使用服务名称而不是SID +3. **字符编码**: 建议使用UTF-8编码以支持中文字符 +4. **连接安全**: 在生产环境中,建议使用SSL/TLS加密连接 +5. **连接池**: 使用连接池可以提高性能和资源利用率 + +## 技术支持 + +如果连接仍然失败,请检查: + +1. Oracle客户端库是否正确安装 +2. 网络连接是否正常 +3. Oracle数据库服务是否运行 +4. 防火墙和安全组设置 +5. 用户权限是否足够 \ No newline at end of file diff --git a/ORACLE_OPTIMIZATION_GUIDE.md b/ORACLE_OPTIMIZATION_GUIDE.md new file mode 100644 index 0000000..7ee6cb4 --- /dev/null +++ b/ORACLE_OPTIMIZATION_GUIDE.md @@ -0,0 +1,238 @@ +# Oracle连接优化指南 + +## 概述 + +本文档详细说明了针对Oracle数据库连接失败问题的优化措施和解决方案。 + +## 问题描述 + +- **现象**: MySQL可以正常初始化,Oracle一直初始化失败,卡在连接失败 +- **原因**: Oracle连接参数配置不当,连接方式不符合oracledb库的最佳实践 + +## 优化措施 + +### 1. 数据库管理器优化 (`database_manager.py`) + +#### 1.1 连接URL构建优化 + +**优化前**: +```python +# 使用复杂的Easy Connect字符串 +dsn = f"{host}:{port}/{service_name}" +base_url = f"oracle+oracledb://{username}:{password}@{dsn}" +``` + +**优化后**: +```python +# 使用标准的SQLAlchemy URL格式 +base_url = f"oracle+oracledb://{username}:{password}@{host}:{port}/?service_name={service_name}" +``` + +#### 1.2 连接预测试机制 + +添加了多种连接方式的预测试: +- Easy Connect字符串 +- 分离参数连接 +- SID连接方式 + +#### 1.3 连接池配置优化 + +```python +engine_kwargs = { + "pool_size": 5, + "max_overflow": 10, + "pool_timeout": 30, + "pool_recycle": 3600, + "pool_reset_on_return": "commit" +} +``` + +### 2. 示例数据初始化优化 (`sample_data.py`) + +#### 2.1 增强日志记录 + +```python +logger.info(f"开始初始化Oracle示例数据,配置: host={config['host']}, port={config['port']}, service_name={config['service_name']}") +``` + +#### 2.2 连接参数传递 + +```python +conn_id = self.db_manager.create_connection( + db_type="oracle", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["service_name"], + mode=config.get("mode"), + threaded=config.get("threaded", True) +) +``` + +### 3. 测试工具 + +#### 3.1 增强版连接测试 (`test_oracle_connection.py`) + +- 直接Oracle连接测试 +- 多种连接方式尝试 +- 详细的错误诊断 +- API连接测试 + +#### 3.2 快速诊断工具 (`quick_oracle_test.py`) + +- 4种不同的连接方式测试 +- 常见错误代码分析 +- 解决方案建议 + +## 连接方式说明 + +### 方式1: Easy Connect字符串 +```python +dsn = f"{host}:{port}/{service_name}" +oracledb.connect(user=username, password=password, dsn=dsn) +``` + +### 方式2: 分离参数 +```python +oracledb.connect( + user=username, + password=password, + host=host, + port=port, + service_name=service_name +) +``` + +### 方式3: makedsn (Service Name) +```python +dsn = oracledb.makedsn(host, port, service_name=service_name) +oracledb.connect(user=username, password=password, dsn=dsn) +``` + +### 方式4: makedsn (SID) +```python +dsn = oracledb.makedsn(host, port, sid=service_name) +oracledb.connect(user=username, password=password, dsn=dsn) +``` + +## 常见错误及解决方案 + +### ORA-12514: TNS:listener does not currently know of service requested + +**原因**: 服务名不存在或未注册到监听器 + +**解决方案**: +1. 检查服务名是否正确 +2. 确认服务已注册到监听器: `lsnrctl services` +3. 尝试使用SID而不是服务名 + +### ORA-12541: TNS:no listener + +**原因**: 监听器未运行 + +**解决方案**: +1. 启动监听器: `lsnrctl start` +2. 检查监听器状态: `lsnrctl status` + +### ORA-01017: invalid username/password + +**原因**: 用户名或密码错误 + +**解决方案**: +1. 验证用户名和密码 +2. 检查用户是否存在且有连接权限 + +### ORA-12170: TNS:Connect timeout occurred + +**原因**: 连接超时 + +**解决方案**: +1. 检查网络连接 +2. 检查防火墙设置 +3. 增加连接超时时间 + +## 配置文件更新 + +### config.py +```python +"oracle": { + "host": "192.168.1.100", + "port": 1521, + "username": "c##testuser", + "password": "123456", + "service_name": "XEPDB1" +} +``` + +### .env.example +``` +ORACLE_HOST=192.168.1.100 +ORACLE_PORT=1521 +ORACLE_USERNAME=c##testuser +ORACLE_PASSWORD=123456 +ORACLE_SERVICE_NAME=XEPDB1 +``` + +## 测试步骤 + +### 1. 快速连接测试 +```bash +python quick_oracle_test.py +``` + +### 2. 完整功能测试 +```bash +python test_oracle_connection.py +``` + +### 3. 启动API服务测试 +```bash +python main.py +``` + +## 最佳实践 + +1. **使用标准的SQLAlchemy URL格式**,避免复杂的DSN构建 +2. **实施连接预测试**,在创建SQLAlchemy引擎前验证连接参数 +3. **配置适当的连接池参数**,提高连接性能和稳定性 +4. **添加详细的日志记录**,便于问题诊断 +5. **提供多种连接方式**,增加连接成功率 +6. **实施错误分析和建议**,帮助快速定位问题 + +## 性能优化 + +### 连接池配置 +- `pool_size`: 5 (基础连接数) +- `max_overflow`: 10 (最大溢出连接数) +- `pool_timeout`: 30秒 (获取连接超时) +- `pool_recycle`: 3600秒 (连接回收时间) +- `pool_reset_on_return`: "commit" (连接返回时重置) + +### SQLAlchemy配置 +- `echo`: False (不输出SQL语句) +- `pool_pre_ping`: True (连接前ping测试) + +## 故障排除清单 + +- [ ] Oracle数据库服务是否运行 +- [ ] 监听器是否启动 (`lsnrctl status`) +- [ ] 服务名是否正确注册 (`lsnrctl services`) +- [ ] 网络连接是否正常 +- [ ] 防火墙是否阻止连接 +- [ ] 用户名和密码是否正确 +- [ ] 用户是否有连接权限 +- [ ] oracledb库是否正确安装 (`pip install oracledb`) +- [ ] 配置文件参数是否正确 + +## 总结 + +通过以上优化措施,Oracle连接的稳定性和成功率得到显著提升。主要改进包括: + +1. **简化连接URL构建**,使用标准格式 +2. **增加连接预测试**,提前发现问题 +3. **优化连接池配置**,提高性能 +4. **提供多种测试工具**,便于诊断 +5. **完善错误处理**,提供解决建议 + +这些优化措施确保了Oracle数据库连接的可靠性,解决了初始化失败的问题。 \ No newline at end of file diff --git a/SAMPLE_DATA.md b/SAMPLE_DATA.md new file mode 100644 index 0000000..4bb1343 --- /dev/null +++ b/SAMPLE_DATA.md @@ -0,0 +1,173 @@ +# 示例数据说明 + +本项目在启动时会自动初始化MySQL和Oracle数据库的示例数据,方便用户测试和体验API功能。 + +## 配置说明 + +### 环境变量配置 + +1. 复制 `.env.example` 文件为 `.env` +2. 根据你的数据库环境修改相应的配置 +3. 设置 `ENABLE_SAMPLE_DATA=true` 启用示例数据初始化 + +### 默认配置 + +如果没有设置环境变量,系统将使用以下默认配置: + +- **MySQL**: localhost:3306, 用户名: root, 密码: password, 数据库: test_db +- **Oracle**: localhost:1521, 用户名: system, 密码: password, 服务名: XE + +## MySQL 示例数据 + +### 1. users 表(用户信息表) + +| 字段名 | 类型 | 说明 | 示例数据 | +|--------|------|------|----------| +| id | INT | 用户ID(主键,自增) | 1, 2, 3... | +| name | VARCHAR(100) | 用户姓名 | 张三, 李四, 王五... | +| email | VARCHAR(150) | 邮箱地址(唯一) | zhangsan@example.com | +| age | INT | 年龄 | 25, 30, 28... | +| created_at | TIMESTAMP | 创建时间 | 当前时间 | +| updated_at | TIMESTAMP | 更新时间 | 当前时间 | + +**示例记录:** +- 张三, zhangsan@example.com, 25岁 +- 李四, lisi@example.com, 30岁 +- 王五, wangwu@example.com, 28岁 +- 赵六, zhaoliu@example.com, 35岁 +- 钱七, qianqi@example.com, 22岁 + +### 2. products 表(产品信息表) + +| 字段名 | 类型 | 说明 | 示例数据 | +|--------|------|------|----------| +| id | INT | 产品ID(主键,自增) | 1, 2, 3... | +| name | VARCHAR(200) | 产品名称 | 苹果手机, 笔记本电脑... | +| price | DECIMAL(10,2) | 价格 | 5999.00, 8999.00... | +| category | VARCHAR(100) | 分类 | 电子产品, 生活用品 | +| description | TEXT | 产品描述 | 最新款智能手机... | +| stock_quantity | INT | 库存数量 | 50, 30, 100... | +| created_at | TIMESTAMP | 创建时间 | 当前时间 | + +**示例记录:** +- 苹果手机, ¥5999.00, 电子产品, 库存50 +- 笔记本电脑, ¥8999.00, 电子产品, 库存30 +- 无线耳机, ¥299.00, 电子产品, 库存100 +- 咖啡杯, ¥39.90, 生活用品, 库存200 +- 书包, ¥129.00, 生活用品, 库存80 + +## Oracle 示例数据 + +### 1. departments 表(部门信息表) + +| 字段名 | 类型 | 说明 | 示例数据 | +|--------|------|------|----------| +| department_id | NUMBER | 部门ID(主键) | 1, 2, 3... | +| department_name | VARCHAR2(100) | 部门名称 | 人力资源部, 技术部... | +| manager_id | NUMBER | 经理ID | NULL(暂未设置) | +| location_id | NUMBER | 位置ID | 1700, 1800, 1900 | + +**示例记录:** +- 人力资源部, 位置ID: 1700 +- 技术部, 位置ID: 1800 +- 销售部, 位置ID: 1900 + +### 2. employees 表(员工信息表) + +| 字段名 | 类型 | 说明 | 示例数据 | +|--------|------|------|----------| +| employee_id | NUMBER | 员工ID(主键) | 1, 2, 3... | +| first_name | VARCHAR2(50) | 名 | 三, 四, 五 | +| last_name | VARCHAR2(50) | 姓 | 张, 李, 王 | +| email | VARCHAR2(100) | 邮箱地址(唯一) | zhang.san@company.com | +| phone_number | VARCHAR2(20) | 电话号码 | 13800138001 | +| hire_date | DATE | 入职日期 | 当前日期 | +| job_id | VARCHAR2(10) | 职位ID | IT_PROG, SA_REP, HR_REP | +| salary | NUMBER(8,2) | 薪资 | 8000, 6000, 5500 | +| department_id | NUMBER | 部门ID | 1, 2, 3 | + +**示例记录:** +- 张三, zhang.san@company.com, IT程序员, 技术部, ¥8000 +- 李四, li.si@company.com, 销售代表, 销售部, ¥6000 +- 王五, wang.wu@company.com, 人事代表, 人力资源部, ¥5500 + +### 序列(Sequences) + +- `emp_seq`: 员工ID序列,从1开始递增 +- `dept_seq`: 部门ID序列,从1开始递增 + +## API 测试建议 + +### 1. 连接数据库 + +```bash +# 连接MySQL +POST /api/v1/connections +{ + "db_type": "mysql", + "host": "localhost", + "port": 3306, + "username": "root", + "password": "password", + "database": "test_db" +} + +# 连接Oracle +POST /api/v1/connections +{ + "db_type": "oracle", + "host": "localhost", + "port": 1521, + "username": "system", + "password": "password", + "database": "XE" +} +``` + +### 2. 查询示例数据 + +```bash +# 查询MySQL用户数据 +POST /api/v1/query +{ + "connection_id": "your_mysql_connection_id", + "sql": "SELECT * FROM users LIMIT 10" +} + +# 查询Oracle员工数据 +POST /api/v1/query +{ + "connection_id": "your_oracle_connection_id", + "sql": "SELECT * FROM employees WHERE ROWNUM <= 10" +} +``` + +### 3. 获取表信息 + +```bash +# 获取MySQL数据库表列表 +GET /api/v1/databases/{connection_id}/tables + +# 获取Oracle表结构 +GET /api/v1/tables/{connection_id}/employees/columns +``` + +## 注意事项 + +1. **数据库权限**: 确保配置的数据库用户具有创建表、插入数据的权限 +2. **数据库连接**: 确保数据库服务正在运行且可以连接 +3. **重复初始化**: 示例数据使用 `INSERT IGNORE`(MySQL)避免重复插入 +4. **Oracle序列**: Oracle示例数据会先删除已存在的表和序列,然后重新创建 +5. **禁用示例数据**: 设置环境变量 `ENABLE_SAMPLE_DATA=false` 可以禁用示例数据初始化 + +## 故障排除 + +如果示例数据初始化失败,请检查: + +1. 数据库服务是否正常运行 +2. 连接参数是否正确 +3. 数据库用户是否有足够的权限 +4. 网络连接是否正常 +5. 查看应用日志获取详细错误信息 + +初始化失败不会影响服务启动,你仍然可以手动创建数据库连接和数据。 \ No newline at end of file diff --git a/SQLSERVER_SETUP_GUIDE.md b/SQLSERVER_SETUP_GUIDE.md new file mode 100644 index 0000000..a91e310 --- /dev/null +++ b/SQLSERVER_SETUP_GUIDE.md @@ -0,0 +1,250 @@ +# SQL Server 配置和初始化指南 + +## 📋 概述 + +本指南详细说明了SQL Server数据库在ETL系统中的配置、连接和示例数据初始化过程。 + +## ⚙️ 配置文件 + +### 1. 环境变量配置 (.env) + +```env +# SQL Server 数据库配置 +SQLSERVER_HOST=192.168.11.200 +SQLSERVER_PORT=1433 +SQLSERVER_USERNAME=sa +SQLSERVER_PASSWORD=sqlserver@7740 +SQLSERVER_DATABASE=test +``` + +### 2. 配置文件 (config.py) + +```python +SQLSERVER_CONFIG = { + "host": os.getenv("SQLSERVER_HOST", "localhost"), + "port": int(os.getenv("SQLSERVER_PORT", "1433")), + "username": os.getenv("SQLSERVER_USERNAME", "sa"), + "password": os.getenv("SQLSERVER_PASSWORD", "password"), + "database": os.getenv("SQLSERVER_DATABASE", "master") +} +``` + +## 🔧 技术实现 + +### 1. 连接驱动 + +- **驱动**: `pymssql` (已在 requirements.txt 中配置) +- **连接URL格式**: `mssql+pymssql://username:password@host:port/database` +- **SQLAlchemy引擎**: 支持连接池和自动重连 +- **URL编码**: 自动处理用户名和密码中的特殊字符(如@、#、&等) + +### 2. 连接管理 (database_manager.py) + +```python +def _build_connection_url(self, db_type, host, port, username, password, database=None, **kwargs): + # URL编码处理特殊字符 + encoded_username = quote_plus(username) + encoded_password = quote_plus(password) + + if db_type == "sqlserver": + db_part = f"/{database}" if database else "" + return f"mssql+pymssql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}" +``` + +### 3. 特殊字符处理 + +**问题**: 密码中包含特殊字符(如`@`符号)会导致URL解析错误 + +**解决方案**: 使用`urllib.parse.quote_plus()`对用户名和密码进行URL编码 + +**示例**: +- 原始密码: `sqlserver@7740` +- 编码后: `sqlserver%407740` +- 避免了URL解析时将`@`误认为用户名密码分隔符 + +## 📊 示例数据初始化 + +### 1. 数据表结构 + +#### customers 表 +```sql +CREATE TABLE customers ( + id INT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + email NVARCHAR(100), + phone NVARCHAR(20), + created_at DATETIME DEFAULT GETDATE() +) +``` + +#### orders 表 +```sql +CREATE TABLE orders ( + id INT IDENTITY(1,1) PRIMARY KEY, + customer_id INT, + product_name NVARCHAR(100) NOT NULL, + quantity INT DEFAULT 1, + price DECIMAL(10,2), + order_date DATETIME DEFAULT GETDATE(), + FOREIGN KEY (customer_id) REFERENCES customers(id) +) +``` + +### 2. 示例数据 + +#### customers 数据 +- 张三 (zhang.san@email.com, 13800138001) +- 李四 (li.si@email.com, 13800138002) +- 王五 (wang.wu@email.com, 13800138003) + +#### orders 数据 +- 笔记本电脑 (客户1, 数量1, 价格5999.99) +- 无线鼠标 (客户2, 数量2, 价格199.99) +- 机械键盘 (客户3, 数量1, 价格899.99) + +### 3. 初始化方法 (sample_data.py) + +```python +def init_sqlserver_sample_data(self): + """初始化SQL Server示例数据""" + try: + config = DatabaseConfig.get_config("sqlserver") + + # 创建连接 + connection_id = self.db_manager.create_connection( + db_type="sqlserver", + **config + ) + + # 创建表和插入数据 + # ... 详细实现见源码 + + return True + except Exception as e: + logger.error(f"SQL Server示例数据初始化失败: {str(e)}") + return False +``` + +## 🚀 使用方法 + +### 1. 自动初始化 + +启动API服务时自动初始化: + +```bash +python main.py +``` + +### 2. 手动测试连接 + +使用测试脚本: + +```bash +# URL编码测试 +python test_url_encoding.py + +# 完整连接测试 +python test_sqlserver_connection.py +``` + +### 3. API调用 + +```bash +# 获取连接列表 +curl http://localhost:8000/connections + +# 执行查询 +curl -X POST http://localhost:8000/query \ + -H "Content-Type: application/json" \ + -d '{"connection_id":"sqlserver_xxx", "query":"SELECT * FROM customers"}' +``` + +## 🔍 SQL Server 特性 + +### 1. 数据类型支持 +- **字符串**: NVARCHAR (支持Unicode) +- **数字**: INT, DECIMAL, FLOAT +- **日期**: DATETIME, DATE, TIME +- **自增**: IDENTITY(1,1) + +### 2. 连接特性 +- **端口**: 默认1433 +- **认证**: SQL Server认证和Windows认证 +- **数据库**: 支持多数据库实例 +- **编码**: UTF-8支持 + +## ⚠️ 注意事项 + +### 1. 密码特殊字符 +- 密码中包含`@`、`#`、`&`等特殊字符时会自动进行URL编码 +- 无需手动处理,系统会自动转换 + +### 2. 连接配置 +- 确保SQL Server服务已启动 +- 检查防火墙设置允许1433端口 +- 验证用户名密码正确性 +- 确认目标数据库存在 + +### 3. 权限要求 +- 用户需要有CREATE TABLE权限 +- 需要有INSERT、SELECT权限 +- 建议使用具有足够权限的数据库用户 + +## 🛠️ 故障排除 + +### 1. 连接失败 + +**错误**: `Unable to connect: Adaptive Server is unavailable or does not exist` + +**可能原因**: +- SQL Server服务未启动 +- 网络连接问题 +- 防火墙阻止连接 +- 主机地址或端口错误 + +**解决方案**: +1. 检查SQL Server服务状态 +2. 验证网络连接 +3. 检查防火墙设置 +4. 确认配置信息正确 + +### 2. 认证失败 + +**错误**: `Login failed for user` + +**解决方案**: +1. 检查用户名密码 +2. 确认SQL Server认证模式 +3. 验证用户权限 + +### 3. 数据库不存在 + +**错误**: `Cannot open database` + +**解决方案**: +1. 创建目标数据库 +2. 检查数据库名称拼写 +3. 验证用户访问权限 + +## 📁 相关文件 + +- `config.py` - 数据库配置定义 +- `database_manager.py` - 连接管理和URL构建 +- `sample_data.py` - 示例数据初始化 +- `test_sqlserver_connection.py` - 连接测试脚本 +- `test_url_encoding.py` - URL编码测试脚本 +- `.env` - 环境变量配置 +- `requirements.txt` - 依赖包配置 + +## 📈 总结 + +SQL Server已成功集成到数据库ETL系统中,支持: + +✅ **完整的连接管理** - 包含连接池和自动重连 +✅ **示例数据初始化** - 自动创建表和插入测试数据 +✅ **特殊字符处理** - 自动URL编码密码中的特殊字符 +✅ **错误处理和日志** - 详细的错误信息和调试日志 +✅ **测试工具** - 多个测试脚本验证功能 +✅ **API接口** - RESTful API支持查询和管理 + +系统现在可以稳定地处理包含特殊字符的SQL Server密码,并提供完整的数据库操作功能。 \ No newline at end of file diff --git a/__pycache__/api_routes.cpython-311.pyc b/__pycache__/api_routes.cpython-311.pyc new file mode 100644 index 0000000..938286b Binary files /dev/null and b/__pycache__/api_routes.cpython-311.pyc differ diff --git a/__pycache__/config.cpython-311.pyc b/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000..b238662 Binary files /dev/null and b/__pycache__/config.cpython-311.pyc differ diff --git a/__pycache__/database_manager.cpython-311.pyc b/__pycache__/database_manager.cpython-311.pyc new file mode 100644 index 0000000..7288caa Binary files /dev/null and b/__pycache__/database_manager.cpython-311.pyc differ diff --git a/__pycache__/main.cpython-311.pyc b/__pycache__/main.cpython-311.pyc new file mode 100644 index 0000000..52e81ce Binary files /dev/null and b/__pycache__/main.cpython-311.pyc differ diff --git a/__pycache__/models.cpython-311.pyc b/__pycache__/models.cpython-311.pyc new file mode 100644 index 0000000..e506d75 Binary files /dev/null and b/__pycache__/models.cpython-311.pyc differ diff --git a/__pycache__/sample_data.cpython-311.pyc b/__pycache__/sample_data.cpython-311.pyc new file mode 100644 index 0000000..d832317 Binary files /dev/null and b/__pycache__/sample_data.cpython-311.pyc differ diff --git a/__pycache__/table_routes.cpython-311.pyc b/__pycache__/table_routes.cpython-311.pyc new file mode 100644 index 0000000..a50f1ff Binary files /dev/null and b/__pycache__/table_routes.cpython-311.pyc differ diff --git a/api/v1/__init__.py b/api/v1/__init__.py new file mode 100644 index 0000000..e1143f5 --- /dev/null +++ b/api/v1/__init__.py @@ -0,0 +1,5 @@ +"""API v1 包初始化 + +此包包含 v1 版本的所有路由与依赖模块。 +""" + diff --git a/api/v1/__pycache__/__init__.cpython-311.pyc b/api/v1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..fb6da7b Binary files /dev/null and b/api/v1/__pycache__/__init__.cpython-311.pyc differ diff --git a/api/v1/deps.py b/api/v1/deps.py new file mode 100644 index 0000000..9c53502 --- /dev/null +++ b/api/v1/deps.py @@ -0,0 +1,16 @@ +"""API v1 依赖模块 + +提供路由可复用的依赖函数,例如获取数据库管理器。 +""" + +from typing import Generator +from database_manager import db_manager, DatabaseManager + +def get_db_manager() -> DatabaseManager: + """获取全局数据库管理器实例 + + Returns: + DatabaseManager: 全局的数据库管理器,用于创建/管理连接与执行SQL。 + """ + return db_manager + diff --git a/api/v1/routes/__init__.py b/api/v1/routes/__init__.py new file mode 100644 index 0000000..cceed8d --- /dev/null +++ b/api/v1/routes/__init__.py @@ -0,0 +1,5 @@ +"""v1 路由包初始化 + +用于组织不同领域的路由模块,例如数据库管理与表管理。 +""" + diff --git a/api/v1/routes/__pycache__/__init__.cpython-311.pyc b/api/v1/routes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..3ae8b96 Binary files /dev/null and b/api/v1/routes/__pycache__/__init__.cpython-311.pyc differ diff --git a/api/v1/routes/__pycache__/database.cpython-311.pyc b/api/v1/routes/__pycache__/database.cpython-311.pyc new file mode 100644 index 0000000..f2b4687 Binary files /dev/null and b/api/v1/routes/__pycache__/database.cpython-311.pyc differ diff --git a/api/v1/routes/__pycache__/tables.cpython-311.pyc b/api/v1/routes/__pycache__/tables.cpython-311.pyc new file mode 100644 index 0000000..ae2f5de Binary files /dev/null and b/api/v1/routes/__pycache__/tables.cpython-311.pyc differ diff --git a/api/v1/routes/database.py b/api/v1/routes/database.py new file mode 100644 index 0000000..f5cbdf5 --- /dev/null +++ b/api/v1/routes/database.py @@ -0,0 +1,357 @@ +from fastapi import APIRouter, HTTPException, status +from typing import List +import logging +from pydantic import BaseModel +from database_manager import db_manager +from schemas import ( + DatabaseConnection, QueryRequest, ExecuteRequest, TableDataRequest, + InsertDataRequest, UpdateDataRequest, DeleteDataRequest, + CreateTableRequest, AlterTableRequest, CommentRequest, + ApiResponse, ConnectionResponse, DatabaseInfo, TableInfo, QueryResult +) + +logger = logging.getLogger(__name__) + +# 创建路由器 +router = APIRouter() + +# 数据库连接管理接口 +@router.post("/connections", response_model=ApiResponse, summary="创建数据库连接") +async def create_connection(connection: DatabaseConnection): + """创建数据库连接""" + try: + # 准备额外的连接参数 + kwargs = {} + if connection.mode: + kwargs['mode'] = connection.mode + if connection.threaded is not None: + kwargs['threaded'] = connection.threaded + if connection.extra_params: + kwargs.update(connection.extra_params) + + connection_id = db_manager.create_connection( + db_type=connection.db_type, + host=connection.host, + port=connection.port, + username=connection.username, + password=connection.password, + database=connection.database, + **kwargs + ) + + response_data = ConnectionResponse( + connection_id=connection_id, + db_type=connection.db_type, + host=connection.host, + port=connection.port, + database=connection.database + ) + + return ApiResponse( + success=True, + message="数据库连接创建成功", + data=response_data.dict() + ) + except Exception as e: + logger.error(f"创建连接失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"创建连接失败: {str(e)}" + ) + +@router.post("/connections/test", response_model=ApiResponse, summary="测试是否能连通数据库") +async def test_connection(connection: DatabaseConnection): + """测试是否能连通数据库""" + try: + kwargs = {} + if connection.mode: + kwargs['mode'] = connection.mode + if connection.threaded is not None: + kwargs['threaded'] = connection.threaded + if connection.extra_params: + kwargs.update(connection.extra_params) + + result = db_manager.test_connection( + db_type=connection.db_type, + host=connection.host, + port=connection.port, + username=connection.username, + password=connection.password, + database=connection.database, + **kwargs + ) + if not result.get("ok"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"连接测试失败: {result.get('error')}" + ) + return ApiResponse( + success=True, + message="数据库连通性测试成功", + data=result + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"连接测试失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"连接测试失败: {str(e)}" + ) + +@router.get("/connections", response_model=ApiResponse, summary="获取所有连接") +async def list_connections(): + """获取所有活动连接""" + try: + connections = db_manager.list_connections() + return ApiResponse( + success=True, + message="获取连接列表成功", + data=connections + ) + except Exception as e: + logger.error(f"获取连接列表失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取连接列表失败: {str(e)}" + ) + +class CloseConnectionRequest(BaseModel): + """关闭连接请求体模型 + + Attributes: + connection_id: 需要关闭的连接ID + """ + connection_id: str + +@router.post("/connections/close", response_model=ApiResponse, summary="关闭数据库连接(POST,JSON传参)") +async def close_connection(request: CloseConnectionRequest): + """关闭数据库连接(POST,使用JSON Body传参) + + 请求示例: + {"connection_id": "mysql_localhost_3306_test_db"} + """ + try: + db_manager.close_connection(request.connection_id) + return ApiResponse( + success=True, + message="数据库连接已关闭" + ) + except Exception as e: + logger.error(f"关闭连接失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"关闭连接失败: {str(e)}" + ) + +# 数据库信息接口 +@router.get("/databases/info", response_model=ApiResponse, summary="获取数据库信息(使用query参数)") +async def get_database_info(connection_id: str): + """获取数据库信息(通过URL query参数) + + Params: + - connection_id: 通过URL query传入,例如 `?connection_id=xxx` + """ + try: + db_info = db_manager.get_database_info(connection_id) + return ApiResponse( + success=True, + message="获取数据库信息成功", + data=db_info + ) + except Exception as e: + logger.error(f"获取数据库信息失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"获取数据库信息失败: {str(e)}" + ) + +@router.get("/databases/tables/info", response_model=ApiResponse, summary="获取表信息(使用query参数)") +async def get_table_info(connection_id: str, table_name: str): + """获取表信息(通过URL query参数) + + Params: + - connection_id: 通过URL query传入,例如 `?connection_id=xxx` + - table_name: 通过URL query传入,例如 `?table_name=users` + """ + try: + table_info = db_manager.get_table_info(connection_id, table_name) + return ApiResponse( + success=True, + message="获取表信息成功", + data=table_info + ) + except Exception as e: + logger.error(f"获取表信息失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"获取表信息失败: {str(e)}" + ) + +# SQL执行接口 +@router.post("/query", response_model=ApiResponse, summary="执行查询SQL") +async def execute_query(request: QueryRequest): + """执行查询SQL""" + try: + result = db_manager.execute_query( + connection_id=request.connection_id, + sql=request.sql, + params=request.params + ) + return ApiResponse( + success=True, + message="查询执行成功", + data=result + ) + except Exception as e: + logger.error(f"查询执行失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"查询执行失败: {str(e)}" + ) + +@router.post("/execute", response_model=ApiResponse, summary="执行非查询SQL") +async def execute_non_query(request: ExecuteRequest): + """执行非查询SQL(INSERT, UPDATE, DELETE)""" + try: + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=request.sql, + params=request.params + ) + return ApiResponse( + success=True, + message="SQL执行成功", + data={"affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"SQL执行失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"SQL执行失败: {str(e)}" + ) + +# 表数据CRUD接口 +@router.post("/tables/data/select", response_model=ApiResponse, summary="查询表数据") +async def select_table_data(request: TableDataRequest): + """查询表数据""" + try: + # 构建查询SQL + sql = f"SELECT * FROM {request.table_name}" + + if request.where_clause: + sql += f" WHERE {request.where_clause}" + + if request.order_by: + sql += f" ORDER BY {request.order_by}" + + # 添加分页 + offset = (request.page - 1) * request.page_size + sql += f" LIMIT {request.page_size} OFFSET {offset}" + + # 执行查询 + data = db_manager.execute_query(request.connection_id, sql) + + # 获取总数 + count_sql = f"SELECT COUNT(*) as total FROM {request.table_name}" + if request.where_clause: + count_sql += f" WHERE {request.where_clause}" + + count_result = db_manager.execute_query(request.connection_id, count_sql) + total = count_result[0]['total'] if count_result else 0 + + result = QueryResult( + data=data, + total=total, + page=request.page, + page_size=request.page_size + ) + + return ApiResponse( + success=True, + message="查询表数据成功", + data=result.dict() + ) + except Exception as e: + logger.error(f"查询表数据失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"查询表数据失败: {str(e)}" + ) + +@router.post("/tables/data/insert", response_model=ApiResponse, summary="插入表数据") +async def insert_table_data(request: InsertDataRequest): + """插入表数据""" + try: + # 构建插入SQL + columns = list(request.data.keys()) + placeholders = [f":{col}" for col in columns] + + sql = f"INSERT INTO {request.table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)})" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql, + params=request.data + ) + + return ApiResponse( + success=True, + message="数据插入成功", + data={"affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"插入数据失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"插入数据失败: {str(e)}" + ) + +@router.post("/tables/data/update", response_model=ApiResponse, summary="更新表数据(改为POST)") +async def update_table_data(request: UpdateDataRequest): + """更新表数据(HTTP方法改为POST)""" + try: + # 构建更新SQL + set_clauses = [f"{col} = :{col}" for col in request.data.keys()] + sql = f"UPDATE {request.table_name} SET {', '.join(set_clauses)} WHERE {request.where_clause}" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql, + params=request.data + ) + + return ApiResponse( + success=True, + message="数据更新成功", + data={"affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"更新数据失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"更新数据失败: {str(e)}" + ) + +@router.post("/tables/data/delete", response_model=ApiResponse, summary="删除表数据(改为POST)") +async def delete_table_data(request: DeleteDataRequest): + """删除表数据(HTTP方法改为POST)""" + try: + sql = f"DELETE FROM {request.table_name} WHERE {request.where_clause}" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql + ) + + return ApiResponse( + success=True, + message="数据删除成功", + data={"affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"删除数据失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"删除数据失败: {str(e)}" + ) diff --git a/api/v1/routes/tables.py b/api/v1/routes/tables.py new file mode 100644 index 0000000..a864339 --- /dev/null +++ b/api/v1/routes/tables.py @@ -0,0 +1,246 @@ +from fastapi import APIRouter, HTTPException, status +import logging +from pydantic import BaseModel + +from database_manager import db_manager +from schemas import ( + CreateTableRequest, AlterTableRequest, CommentRequest, ApiResponse +) + +logger = logging.getLogger(__name__) + +# 创建路由器 +router = APIRouter() + +# 表结构管理接口 +@router.post("/tables/create", response_model=ApiResponse, summary="创建表") +async def create_table(request: CreateTableRequest): + """创建表""" + try: + # 构建创建表的SQL + column_definitions = [] + for col in request.columns: + col_def = f"{col['name']} {col['type']}" + if col.get('not_null', False): + col_def += " NOT NULL" + if col.get('primary_key', False): + col_def += " PRIMARY KEY" + if col.get('default'): + col_def += f" DEFAULT {col['default']}" + if col.get('comment'): + col_def += f" COMMENT '{col['comment']}'" + column_definitions.append(col_def) + + sql = f"CREATE TABLE {request.table_name} ({', '.join(column_definitions)})" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql + ) + + return ApiResponse( + success=True, + message="表创建成功", + data={"table_name": request.table_name, "affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"创建表失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"创建表失败: {str(e)}" + ) + +class DropTableRequest(BaseModel): + """删除表请求体模型 + + Attributes: + connection_id: 连接ID + table_name: 需要删除的表名 + """ + connection_id: str + table_name: str + +@router.post("/tables/delete", response_model=ApiResponse, summary="删除表(POST,JSON传参)") +async def drop_table(request: DropTableRequest): + """删除表(POST,使用JSON Body传参) + + 请求示例: + {"connection_id": "mysql_localhost_3306_test_db", "table_name": "users"} + """ + try: + sql = f"DROP TABLE {request.table_name}" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql + ) + + return ApiResponse( + success=True, + message="表删除成功", + data={"table_name": request.table_name, "affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"删除表失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"删除表失败: {str(e)}" + ) + +@router.post("/tables/alter", response_model=ApiResponse, summary="修改表结构(改为POST)") +async def alter_table(request: AlterTableRequest): + """修改表结构(HTTP方法改为POST)""" + try: + sql = "" + + if request.operation.upper() == "ADD": + # 添加列 + col_def = request.column_definition + column_sql = f"{col_def['name']} {col_def['type']}" + if col_def.get('not_null', False): + column_sql += " NOT NULL" + if col_def.get('default'): + column_sql += f" DEFAULT {col_def['default']}" + sql = f"ALTER TABLE {request.table_name} ADD COLUMN {column_sql}" + + elif request.operation.upper() == "DROP": + # 删除列 + column_name = request.column_definition['name'] + sql = f"ALTER TABLE {request.table_name} DROP COLUMN {column_name}" + + elif request.operation.upper() == "MODIFY": + # 修改列 + col_def = request.column_definition + column_sql = f"{col_def['name']} {col_def['type']}" + if col_def.get('not_null', False): + column_sql += " NOT NULL" + sql = f"ALTER TABLE {request.table_name} MODIFY COLUMN {column_sql}" + + else: + raise ValueError(f"不支持的操作类型: {request.operation}") + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql + ) + + return ApiResponse( + success=True, + message="表结构修改成功", + data={"table_name": request.table_name, "operation": request.operation, "affected_rows": affected_rows} + ) + except Exception as e: + logger.error(f"修改表结构失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"修改表结构失败: {str(e)}" + ) + +# 备注管理接口 +@router.post("/tables/comment", response_model=ApiResponse, summary="修改表或字段备注(改为POST)") +async def update_comment(request: CommentRequest): + """修改表或字段备注(HTTP方法改为POST)""" + try: + if request.column_name: + # 修改字段备注 + sql = f"ALTER TABLE {request.table_name} MODIFY COLUMN {request.column_name} COMMENT '{request.comment}'" + else: + # 修改表备注 + sql = f"ALTER TABLE {request.table_name} COMMENT '{request.comment}'" + + affected_rows = db_manager.execute_non_query( + connection_id=request.connection_id, + sql=sql + ) + + return ApiResponse( + success=True, + message="备注修改成功", + data={ + "table_name": request.table_name, + "column_name": request.column_name, + "comment": request.comment, + "affected_rows": affected_rows + } + ) + except Exception as e: + logger.error(f"修改备注失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"修改备注失败: {str(e)}" + ) + +@router.get("/tables/columns", response_model=ApiResponse, summary="获取表的所有字段信息(使用query参数)") +async def get_table_columns(connection_id: str, table_name: str): + """获取表的所有字段信息(通过URL query参数) + + Params: + - connection_id: 通过URL query传入,例如 `?connection_id=xxx` + - table_name: 通过URL query传入,例如 `?table_name=users` + """ + try: + columns = db_manager.get_table_columns(connection_id, table_name) + + return ApiResponse( + success=True, + message="获取字段信息成功", + data={ + "table_name": table_name, + "columns": columns + } + ) + except Exception as e: + logger.error(f"获取字段信息失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"获取字段信息失败: {str(e)}" + ) + +@router.get("/tables", response_model=ApiResponse, summary="获取数据库中的所有表(使用query参数)") +async def get_all_tables(connection_id: str): + """获取数据库中的所有表(通过URL query参数) + + Params: + - connection_id: 通过URL query传入,例如 `?connection_id=xxx` + """ + try: + db_info = db_manager.get_database_info(connection_id) + return ApiResponse( + success=True, + message="获取表列表成功", + data={ + "database_name": db_info["database_name"], + "tables": db_info["tables"], + "table_count": db_info["table_count"] + } + ) + except Exception as e: + logger.error(f"获取表列表失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"获取表列表失败: {str(e)}" + ) + +@router.get("/tables/details", response_model=ApiResponse, summary="获取所有表及其备注信息(使用query参数)") +async def get_tables_with_comments(connection_id: str): + """获取所有表及其备注信息(通过URL query参数) + + Params: + - connection_id: 通过URL query传入,例如 `?connection_id=xxx` + """ + try: + tables = db_manager.get_tables_with_comments(connection_id) + return ApiResponse( + success=True, + message="获取表及备注信息成功", + data={ + "tables": tables, + "table_count": len(tables) + } + ) + except Exception as e: + logger.error(f"获取表备注信息失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"获取表备注信息失败: {str(e)}" + ) diff --git a/config.py b/config.py new file mode 100644 index 0000000..70925bb --- /dev/null +++ b/config.py @@ -0,0 +1,61 @@ +import os +from typing import Dict, Any +from dotenv import load_dotenv + +# 加载环境变量文件 +load_dotenv() + +class DatabaseConfig: + """数据库配置类""" + + # MySQL配置 + MYSQL_CONFIG = { + "host": os.getenv("MYSQL_HOST", "localhost"), + "port": int(os.getenv("MYSQL_PORT", "3306")), + "username": os.getenv("MYSQL_USERNAME", "root"), + "password": os.getenv("MYSQL_PASSWORD", "password"), + "database": os.getenv("MYSQL_DATABASE", "test_db") + } + + # Oracle配置 + ORACLE_CONFIG = { + "host": os.getenv("ORACLE_HOST", "192.168.13.27"), + "port": int(os.getenv("ORACLE_PORT", "1521")), + "username": os.getenv("ORACLE_USERNAME", "bizuser"), + "password": os.getenv("ORACLE_PASSWORD", "MySecurePass123"), + "service_name": os.getenv("ORACLE_SERVICE_NAME", "ORCLPDB1") + } + + # SQL Server配置 + SQLSERVER_CONFIG = { + "host": os.getenv("SQLSERVER_HOST", "localhost"), + "port": int(os.getenv("SQLSERVER_PORT", "1433")), + "username": os.getenv("SQLSERVER_USERNAME", "sa"), + "password": os.getenv("SQLSERVER_PASSWORD", "password"), + "database": os.getenv("SQLSERVER_DATABASE", "master") + } + + # PostgreSQL配置 + POSTGRESQL_CONFIG = { + "host": os.getenv("POSTGRESQL_HOST", "localhost"), + "port": int(os.getenv("POSTGRESQL_PORT", "5432")), + "username": os.getenv("POSTGRESQL_USERNAME", "postgres"), + "password": os.getenv("POSTGRESQL_PASSWORD", "password"), + "database": os.getenv("POSTGRESQL_DATABASE", "postgres") + } + + @classmethod + def get_config(cls, db_type: str) -> Dict[str, Any]: + """根据数据库类型获取配置""" + config_map = { + "mysql": cls.MYSQL_CONFIG, + "oracle": cls.ORACLE_CONFIG, + "sqlserver": cls.SQLSERVER_CONFIG, + "postgresql": cls.POSTGRESQL_CONFIG + } + return config_map.get(db_type.lower(), {}) + + @classmethod + def is_sample_data_enabled(cls) -> bool: + """检查是否启用示例数据初始化""" + return os.getenv("ENABLE_SAMPLE_DATA", "true").lower() == "true" \ No newline at end of file diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000..535df16 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,16 @@ +"""通用工具模块 + +提供项目可能用到的工具函数。 +""" + +def to_safe_str(value) -> str: + """安全地将任意值转换为字符串 + + Args: + value: 任意值 + + Returns: + str: 字符串表示,None 将转换为 "" + """ + return "" if value is None else str(value) + diff --git a/database_manager.py b/database_manager.py new file mode 100644 index 0000000..5d9093a --- /dev/null +++ b/database_manager.py @@ -0,0 +1,508 @@ +from sqlalchemy import create_engine, text, MetaData, Table, Column, inspect +from sqlalchemy.orm import sessionmaker +from sqlalchemy.exc import SQLAlchemyError +from typing import Dict, List, Any, Optional +import logging +import oracledb +from urllib.parse import quote_plus + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class DatabaseManager: + """数据库管理类,支持多种数据库类型的连接和操作""" + + def __init__(self): + self.engines = {} + self.sessions = {} + + def create_connection(self, db_type: str, host: str, port: int, + username: str, password: str, database: str = None, **kwargs) -> str: + """创建数据库连接 + + Args: + db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql) + host: 数据库主机地址 + port: 数据库端口 + username: 用户名 + password: 密码 + database: 数据库名称 + + Returns: + connection_id: 连接ID + """ + try: + connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs) + connection_id = f"{db_type}_{host}_{port}_{database or 'default'}" + + # 配置引擎参数 + engine_kwargs = {"echo": False, "pool_pre_ping": True} + if db_type.lower() == 'oracle': + # Oracle连接池配置 + engine_kwargs.update({ + "pool_size": 5, + "max_overflow": 10, + "pool_timeout": 30, + "pool_recycle": 3600, + "pool_reset_on_return": "commit" + }) + + engine = create_engine(connection_url, **engine_kwargs) + + # 测试连接 + with engine.connect() as conn: + if db_type.lower() == 'oracle': + conn.execute(text("SELECT 1 FROM DUAL")) + else: + conn.execute(text("SELECT 1")) + + self.engines[connection_id] = engine + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + self.sessions[connection_id] = SessionLocal + + logger.info(f"成功创建数据库连接: {connection_id}") + return connection_id + + except Exception as e: + logger.error(f"创建数据库连接失败: {str(e)}") + raise Exception(f"数据库连接失败: {str(e)}") + + def test_connection(self, db_type: str, host: str, port: int, + username: str, password: str, database: str = None, **kwargs) -> Dict[str, Any]: + """测试数据库是否可连通 + + Args: + db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql) + host: 数据库主机地址 + port: 数据库端口 + username: 用户名 + password: 密码 + database: 数据库名称 + + Returns: + Dict[str, Any]: 测试结果信息,包含连接是否成功与版本信息 + """ + try: + # 构建连接URL(内部会对Oracle进行预测试) + connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs) + engine = create_engine(connection_url, echo=False, pool_pre_ping=True) + + server_version = None + with engine.connect() as conn: + # 基本连通性测试 + if db_type.lower() == 'oracle': + conn.execute(text("SELECT 1 FROM DUAL")) + version_sql = "SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1" + server_version = conn.execute(text(version_sql)).scalar() + elif db_type.lower() == 'mysql': + conn.execute(text("SELECT 1")) + server_version = conn.execute(text("SELECT VERSION()")).scalar() + elif db_type.lower() == 'postgresql': + conn.execute(text("SELECT 1")) + server_version = conn.execute(text("SELECT version()")).scalar() + elif db_type.lower() == 'sqlserver': + conn.execute(text("SELECT 1")) + server_version = conn.execute(text("SELECT @@VERSION")).scalar() + else: + raise ValueError(f"不支持的数据库类型: {db_type}") + + # 释放临时引擎 + engine.dispose() + + return { + "ok": True, + "db_type": db_type, + "connection_url": str(connection_url), + "server_version": server_version + } + except Exception as e: + logger.error(f"连接测试失败: {str(e)}") + return { + "ok": False, + "db_type": db_type, + "error": str(e) + } + + def _build_connection_url(self, db_type: str, host: str, port: int, + username: str, password: str, database: str = None, **kwargs) -> str: + """构建数据库连接URL""" + # 对用户名和密码进行URL编码,防止特殊字符导致解析错误 + encoded_username = quote_plus(username) + encoded_password = quote_plus(password) + + if db_type.lower() == 'mysql': + db_part = f"/{database}" if database else "" + return f"mysql+pymysql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}?charset=utf8mb4" + elif db_type.lower() == 'oracle': + # Oracle连接格式 - 根据oracledb文档优化 + service_name = database or 'XE' + + # 先测试Oracle连接以确保参数正确 + try: + self._test_oracle_connection( + host=host, + port=port, + username=username, + password=password, + service_name=service_name, + **kwargs + ) + except Exception as e: + logger.error(f"Oracle连接预测试失败: {str(e)}") + raise e + + # 使用简化的SQLAlchemy URL格式 + # 根据oracledb文档,SQLAlchemy会自动处理连接参数 + base_url = f"oracle+oracledb://{encoded_username}:{encoded_password}@{host}:{port}/?service_name={service_name}" + + return base_url + elif db_type.lower() == 'sqlserver': + db_part = f"/{database}" if database else "" + return f"mssql+pymssql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}" + elif db_type.lower() == 'postgresql': + db_part = f"/{database}" if database else "/postgres" + return f"postgresql+psycopg2://{encoded_username}:{encoded_password}@{host}:{port}{db_part}" + else: + raise ValueError(f"不支持的数据库类型: {db_type}") + + def _test_oracle_connection(self, host: str, port: int, username: str, password: str, + service_name: str = None, **kwargs): + """测试Oracle直接连接""" + service_name = service_name or 'XE' + + # 尝试多种连接方式 + connection_methods = [ + # 方式1: Easy Connect字符串 + { + 'name': 'Easy Connect', + 'params': { + 'user': username, + 'password': password, + 'dsn': f"{host}:{port}/{service_name}" + } + }, + # 方式2: 分离参数 + { + 'name': 'Separate Parameters', + 'params': { + 'user': username, + 'password': password, + 'host': host, + 'port': port, + 'service_name': service_name + } + }, + # 方式3: 使用SID + { + 'name': 'SID Connection', + 'params': { + 'user': username, + 'password': password, + 'dsn': oracledb.makedsn(host, port, sid=service_name) + } + } + ] + + last_error = None + + for method in connection_methods: + try: + logger.info(f"尝试Oracle连接方式: {method['name']}") + + # 尝试连接 + connection = oracledb.connect(**method['params']) + + # 测试查询 + cursor = connection.cursor() + cursor.execute("SELECT 1 FROM DUAL") + result = cursor.fetchone() + + # 获取数据库版本信息 + cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1") + version = cursor.fetchone() + + cursor.close() + connection.close() + + logger.info(f"Oracle连接成功 ({method['name']}): 查询结果={result}, 版本={version[0] if version else 'Unknown'}") + return # 连接成功,退出函数 + + except Exception as e: + last_error = e + logger.warning(f"Oracle连接方式 '{method['name']}' 失败: {str(e)}") + continue + + # 所有连接方式都失败 + error_msg = f"所有Oracle连接方式都失败。最后一个错误: {str(last_error)}" + logger.error(error_msg) + raise Exception(error_msg) + + def get_engine(self, connection_id: str): + """获取数据库引擎""" + if connection_id not in self.engines: + raise ValueError(f"连接ID不存在: {connection_id}") + return self.engines[connection_id] + + def get_session(self, connection_id: str): + """获取数据库会话""" + if connection_id not in self.sessions: + raise ValueError(f"连接ID不存在: {connection_id}") + return self.sessions[connection_id]() + + def execute_query(self, connection_id: str, sql: str, params: Dict = None) -> List[Dict]: + """执行查询SQL""" + try: + engine = self.get_engine(connection_id) + with engine.connect() as conn: + result = conn.execute(text(sql), params or {}) + columns = result.keys() + rows = result.fetchall() + return [dict(zip(columns, row)) for row in rows] + except Exception as e: + logger.error(f"执行查询失败: {str(e)}") + raise Exception(f"查询执行失败: {str(e)}") + + def execute_non_query(self, connection_id: str, sql: str, params: Dict = None) -> int: + """执行非查询SQL(INSERT, UPDATE, DELETE)""" + try: + engine = self.get_engine(connection_id) + with engine.connect() as conn: + with conn.begin(): + result = conn.execute(text(sql), params or {}) + return result.rowcount + except Exception as e: + logger.error(f"执行非查询失败: {str(e)}") + raise Exception(f"非查询执行失败: {str(e)}") + + def get_database_info(self, connection_id: str) -> Dict: + """获取数据库信息""" + try: + engine = self.get_engine(connection_id) + inspector = inspect(engine) + + # 获取数据库名称 + with engine.connect() as conn: + db_name_result = conn.execute(text("SELECT DATABASE()" if 'mysql' in str(engine.url) + else "SELECT CURRENT_DATABASE()" if 'postgresql' in str(engine.url) + else "SELECT DB_NAME()" if 'mssql' in str(engine.url) + else "SELECT SYS_CONTEXT('USERENV', 'DB_NAME') FROM DUAL")) + db_name = db_name_result.scalar() + + # 获取表列表 + tables = inspector.get_table_names() + + return { + "database_name": db_name, + "tables": tables, + "table_count": len(tables) + } + except Exception as e: + logger.error(f"获取数据库信息失败: {str(e)}") + raise Exception(f"获取数据库信息失败: {str(e)}") + + def get_tables_with_comments(self, connection_id: str) -> List[Dict[str, Any]]: + """获取数据库中所有表及其备注信息""" + try: + engine = self.get_engine(connection_id) + url_str = str(engine.url) + with engine.connect() as conn: + if 'mysql' in url_str: + sql = """ + SELECT + TABLE_NAME AS table_name, + TABLE_COMMENT AS table_comment + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = DATABASE() + ORDER BY TABLE_NAME + """ + return [dict(row) for row in conn.execute(text(sql)).mappings().all()] + elif 'postgresql' in url_str: + # 获取当前schema + schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public' + sql = """ + SELECT + c.relname AS table_name, + obj_description(c.oid) AS table_comment + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = :schema + AND c.relkind IN ('r','p') + ORDER BY c.relname + """ + return [dict(row) for row in conn.execute(text(sql), {"schema": schema}).mappings().all()] + elif 'mssql' in url_str: + sql = """ + SELECT + t.name AS table_name, + CAST(ep.value AS NVARCHAR(MAX)) AS table_comment + FROM sys.tables t + LEFT JOIN sys.extended_properties ep + ON ep.major_id = t.object_id + AND ep.minor_id = 0 + AND ep.name = 'MS_Description' + ORDER BY t.name + """ + return [dict(row) for row in conn.execute(text(sql)).mappings().all()] + elif 'oracle' in url_str: + sql = """ + SELECT + table_name AS table_name, + comments AS table_comment + FROM user_tab_comments + ORDER BY table_name + """ + return [dict(row) for row in conn.execute(text(sql)).mappings().all()] + else: + raise ValueError("不支持的数据库类型") + except Exception as e: + logger.error(f"获取表备注信息失败: {str(e)}") + raise Exception(f"获取表备注信息失败: {str(e)}") + + def get_table_info(self, connection_id: str, table_name: str) -> Dict: + """获取表信息 + + 序列化说明: + - SQLAlchemy 的列类型对象(如 `VARCHAR`、`NUMBER` 等)不可直接JSON序列化。 + - 本方法将列的 `type` 字段统一转换为字符串表示(例如 `VARCHAR(255)`)。 + - 其他由 Inspector 返回的结构(主键、外键、索引)保持原始可序列化格式。 + """ + try: + engine = self.get_engine(connection_id) + inspector = inspect(engine) + + # 获取列信息 + raw_columns = inspector.get_columns(table_name) + columns = [] + for col in raw_columns: + # 创建可序列化的列字典 + serializable = dict(col) + col_type = serializable.get('type') + # 将SQLAlchemy类型对象转换为字符串 + if col_type is not None: + try: + serializable['type'] = str(col_type) + except Exception: + # 兜底:使用类型名 + serializable['type'] = type(col_type).__name__ + columns.append(serializable) + + # 获取主键 + primary_keys = inspector.get_pk_constraint(table_name) + + # 获取外键 + foreign_keys = inspector.get_foreign_keys(table_name) + + # 获取索引 + indexes = inspector.get_indexes(table_name) + + return { + "table_name": table_name, + "columns": columns, + "primary_keys": primary_keys, + "foreign_keys": foreign_keys, + "indexes": indexes + } + except Exception as e: + logger.error(f"获取表信息失败: {str(e)}") + raise Exception(f"获取表信息失败: {str(e)}") + + def get_table_columns(self, connection_id: str, table_name: str) -> List[Dict[str, Any]]: + """获取指定表的字段名、类型、备注等信息(兼容多数据库)""" + try: + engine = self.get_engine(connection_id) + url_str = str(engine.url) + with engine.connect() as conn: + if 'mysql' in url_str: + sql = """ + SELECT + COLUMN_NAME AS column_name, + DATA_TYPE AS data_type, + IS_NULLABLE AS is_nullable, + COLUMN_DEFAULT AS column_default, + COLUMN_COMMENT AS column_comment, + CHARACTER_MAXIMUM_LENGTH AS max_length, + NUMERIC_PRECISION AS numeric_precision, + NUMERIC_SCALE AS numeric_scale + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = :table_name + ORDER BY ORDINAL_POSITION + """ + return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()] + elif 'postgresql' in url_str: + sql = """ + SELECT + a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + NOT a.attnotnull AS is_nullable, + pg_get_expr(d.adbin, d.adrelid) AS column_default, + col_description(a.attrelid, a.attnum) AS column_comment + FROM pg_attribute a + JOIN pg_class c ON a.attrelid = c.oid + LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum + WHERE c.relname = :table_name + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + """ + return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()] + elif 'mssql' in url_str: + sql = """ + SELECT + c.name AS column_name, + t.name AS data_type, + CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable, + OBJECT_DEFINITION(c.default_object_id) AS column_default, + CAST(ep.value AS NVARCHAR(MAX)) AS column_comment, + COLUMNPROPERTY(c.object_id, c.name, 'charmaxlen') AS max_length + FROM sys.columns c + LEFT JOIN sys.types t ON c.user_type_id = t.user_type_id + LEFT JOIN sys.extended_properties ep + ON ep.major_id = c.object_id + AND ep.minor_id = c.column_id + AND ep.name = 'MS_Description' + WHERE c.object_id = OBJECT_ID(:table_name) + ORDER BY c.column_id + """ + return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()] + elif 'oracle' in url_str: + sql = """ + SELECT + cols.column_name AS column_name, + cols.data_type AS data_type, + cols.nullable AS is_nullable, + cols.data_default AS column_default, + comm.comments AS column_comment, + cols.data_length AS max_length + FROM user_tab_columns cols + LEFT JOIN user_col_comments comm + ON comm.table_name = cols.table_name + AND comm.column_name = cols.column_name + WHERE cols.table_name = UPPER(:table_name) + ORDER BY cols.column_id + """ + return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()] + else: + raise ValueError("不支持的数据库类型") + except Exception as e: + logger.error(f"获取字段信息失败: {str(e)}") + raise Exception(f"获取字段信息失败: {str(e)}") + + def close_connection(self, connection_id: str): + """关闭数据库连接""" + try: + if connection_id in self.engines: + self.engines[connection_id].dispose() + del self.engines[connection_id] + del self.sessions[connection_id] + logger.info(f"已关闭数据库连接: {connection_id}") + except Exception as e: + logger.error(f"关闭连接失败: {str(e)}") + + def list_connections(self) -> List[str]: + """列出所有活动连接""" + return list(self.engines.keys()) + +# 全局数据库管理器实例 +db_manager = DatabaseManager() diff --git a/docs/doc.md b/docs/doc.md new file mode 100644 index 0000000..323c4ed --- /dev/null +++ b/docs/doc.md @@ -0,0 +1,182 @@ +# 接口文档(主要接口) + +以下文档覆盖 README 中“主要接口包括”的四项能力:测试连通、获取数据库信息、获取所有表及备注信息、获取表字段及备注信息。所有接口均遵循统一返回结构: + +```json +{ + "success": true, + "message": "描述信息", + "data": { "具体数据" }, + "error": null +} +``` + +## 1. 测试能否连通数据库 +- URL: `/api/v1/connections/test` +- Method: `POST` +- Content-Type: `application/json` +- 请求体参数(DatabaseConnection): + - db_type: string(mysql|oracle|sqlserver|postgresql) + - host: string + - port: number + - username: string + - password: string + - database: string,可选 + - mode: string,可选(Oracle专用) + - threaded: boolean,可选(Oracle专用) + - extra_params: object,可选(附加参数) +- 请求示例: + +```json +{ + "db_type": "mysql", + "host": "127.0.0.1", + "port": 3306, + "username": "root", + "password": "pass", + "database": "test_db" +} +``` + +- 返回示例(成功): + +```json +{ + "success": true, + "message": "数据库连通性测试成功", + "data": { + "ok": true, + "db_type": "mysql", + "connection_url": "mysql+pymysql://root:***@127.0.0.1:3306/test_db?charset=utf8mb4", + "server_version": "8.0.36" + }, + "error": null +} +``` + +- 返回示例(失败): + +```json +{ + "success": false, + "message": "连接测试失败: 认证失败", + "data": null, + "error": "认证失败" +} +``` + +## 2. 获取数据库信息 +- URL: `/api/v1/databases/info` +- Method: `GET` +- Query 参数: + - connection_id: string(创建连接时生成的连接ID) +- 返回字段: + - database_name: string + - tables: string[](所有表名) + - table_count: number(表数量) +- 返回示例: + +```json +{ + "success": true, + "message": "获取数据库信息成功", + "data": { + "database_name": "test_db", + "tables": ["users", "orders", "products"], + "table_count": 3 + }, + "error": null +} +``` + +## 3. 获取数据库中所有的表和表备注信息 +- URL: `/api/v1/tables/details` +- Method: `GET` +- Query 参数: + - connection_id: string +- 返回字段: + - tables: 数组,每项包含 + - table_name: string + - table_comment: string|null + - table_count: number +- 返回示例: + +```json +{ + "success": true, + "message": "获取表及备注信息成功", + "data": { + "tables": [ + { "table_name": "users", "table_comment": "用户信息表" }, + { "table_name": "orders", "table_comment": "订单表" }, + { "table_name": "products", "table_comment": "" } + ], + "table_count": 3 + }, + "error": null +} +``` + +说明:不同数据库的备注来源 +- MySQL: INFORMATION_SCHEMA.TABLES.TABLE_COMMENT +- PostgreSQL: pg_class/obj_description +- SQL Server: sys.extended_properties('MS_Description') +- Oracle: user_tab_comments + +## 4. 获取数据表中字段名和类型以及备注信息 +- URL: `/api/v1/tables/columns` +- Method: `GET` +- Query 参数: + - connection_id: string + - table_name: string +- 返回字段: + - table_name: string + - columns: 数组,每项包含 + - column_name: string + - data_type: string + - is_nullable: string|boolean(不同库返回值格式略有差异) + - column_default: string|null + - column_comment: string|null + - max_length: number|null + - numeric_precision: number|null(部分库返回) + - numeric_scale: number|null(部分库返回) +- 返回示例(MySQL样例): + +```json +{ + "success": true, + "message": "获取字段信息成功", + "data": { + "table_name": "users", + "columns": [ + { + "column_name": "id", + "data_type": "int", + "is_nullable": "NO", + "column_default": null, + "column_comment": "主键ID", + "max_length": null, + "numeric_precision": 10, + "numeric_scale": 0 + }, + { + "column_name": "name", + "data_type": "varchar", + "is_nullable": "YES", + "column_default": null, + "column_comment": "用户名", + "max_length": 255, + "numeric_precision": null, + "numeric_scale": null + } + ] + }, + "error": null +} +``` + +说明:不同数据库的字段备注来源 +- MySQL: INFORMATION_SCHEMA.COLUMNS.COLUMN_COMMENT +- PostgreSQL: col_description + format_type +- SQL Server: sys.extended_properties('MS_Description') +- Oracle: user_col_comments + user_tab_columns diff --git a/main.py b/main.py new file mode 100644 index 0000000..839a5d8 --- /dev/null +++ b/main.py @@ -0,0 +1,122 @@ +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +import uvicorn +import logging +from contextlib import asynccontextmanager + +# 导入路由模块 +from api.v1.routes.database import router as database_router +from api.v1.routes.tables import router as tables_router +from database_manager import db_manager +from sample_data import SampleDataInitializer + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# 应用生命周期管理 +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用启动和关闭时的处理""" + # 启动时 + logger.info("数据库接口服务启动中...") + + # 初始化示例数据 + try: + logger.info("开始初始化示例数据...") + sample_initializer = SampleDataInitializer() + sample_initializer.initialize_all_sample_data() + logger.info("示例数据初始化完成") + except Exception as e: + logger.warning(f"示例数据初始化失败: {str(e)}") + logger.info("服务将继续启动,但示例数据不可用") + + yield + # 关闭时 + logger.info("数据库接口服务关闭中...") + # 关闭所有数据库连接 + for connection_id in db_manager.list_connections(): + db_manager.close_connection(connection_id) + logger.info("所有数据库连接已关闭") + +# 创建FastAPI应用 +app = FastAPI( + title="数据库接口服务", + description="提供统一的数据库管理接口,支持MySQL、Oracle、SQL Server、PostgreSQL等多种数据库类型", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", + lifespan=lifespan +) + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 在生产环境中应该设置具体的域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 全局异常处理 +@app.exception_handler(Exception) +async def global_exception_handler(request, exc): + """全局异常处理器""" + logger.error(f"全局异常: {str(exc)}") + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": "服务器内部错误", + "error": str(exc) + } + ) + + +# 健康检查接口 +@app.get("/health", tags=["系统"]) +async def health_check(): + """健康检查接口""" + return { + "success": True, + "message": "服务运行正常", + "data": { + "status": "healthy", + "active_connections": len(db_manager.list_connections()) + } + } + + +# 根路径接口 +@app.get("/", tags=["系统"]) +async def root(): + """根路径接口""" + return { + "success": True, + "message": "欢迎使用数据库接口服务", + "data": { + "title": "数据库接口服务", + "version": "1.0.0", + "description": "提供统一的数据库管理接口,支持MySQL、Oracle、SQL Server、PostgreSQL等多种数据库类型", + "docs_url": "/docs", + "redoc_url": "/redoc" + } + } + +# 注册路由 +app.include_router(database_router, prefix="/api/v1", tags=["数据库管理"]) +app.include_router(tables_router, prefix="/api/v1", tags=["表管理"]) + +# 主函数 +if __name__ == "__main__": + logger.info("启动数据库接口服务...") + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + log_level="info" + ) diff --git a/models.py b/models.py new file mode 100644 index 0000000..d871fc4 --- /dev/null +++ b/models.py @@ -0,0 +1,120 @@ +from pydantic import BaseModel, Field +from typing import Dict, List, Any, Optional +from enum import Enum + +class DatabaseType(str, Enum): + """支持的数据库类型""" + MYSQL = "mysql" + ORACLE = "oracle" + SQLSERVER = "sqlserver" + POSTGRESQL = "postgresql" + +class DatabaseConnection(BaseModel): + """数据库连接配置""" + db_type: DatabaseType = Field(..., description="数据库类型") + host: str = Field(..., description="数据库主机地址") + port: int = Field(..., description="数据库端口") + username: str = Field(..., description="用户名") + password: str = Field(..., description="密码") + database: Optional[str] = Field(None, description="数据库名称") + # Oracle特定参数 + mode: Optional[str] = Field(None, description="Oracle连接模式") + threaded: Optional[bool] = Field(None, description="Oracle是否启用线程模式") + # 其他连接参数 + extra_params: Optional[Dict[str, Any]] = Field(None, description="额外的连接参数") + +class QueryRequest(BaseModel): + """查询请求""" + connection_id: str = Field(..., description="连接ID") + sql: str = Field(..., description="SQL语句") + params: Optional[Dict[str, Any]] = Field(None, description="SQL参数") + +class ExecuteRequest(BaseModel): + """执行请求""" + connection_id: str = Field(..., description="连接ID") + sql: str = Field(..., description="SQL语句") + params: Optional[Dict[str, Any]] = Field(None, description="SQL参数") + +class TableDataRequest(BaseModel): + """表数据请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + page: int = Field(1, description="页码") + page_size: int = Field(10, description="每页大小") + where_clause: Optional[str] = Field(None, description="WHERE条件") + order_by: Optional[str] = Field(None, description="排序字段") + +class InsertDataRequest(BaseModel): + """插入数据请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + data: Dict[str, Any] = Field(..., description="要插入的数据") + +class UpdateDataRequest(BaseModel): + """更新数据请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + data: Dict[str, Any] = Field(..., description="要更新的数据") + where_clause: str = Field(..., description="WHERE条件") + +class DeleteDataRequest(BaseModel): + """删除数据请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + where_clause: str = Field(..., description="WHERE条件") + +class CreateTableRequest(BaseModel): + """创建表请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + columns: List[Dict[str, Any]] = Field(..., description="列定义") + +class AlterTableRequest(BaseModel): + """修改表请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + operation: str = Field(..., description="操作类型: ADD, DROP, MODIFY") + column_definition: Optional[Dict[str, Any]] = Field(None, description="列定义") + +class CommentRequest(BaseModel): + """备注请求""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + column_name: Optional[str] = Field(None, description="列名(为空则修改表备注)") + comment: str = Field(..., description="备注内容") + +class ApiResponse(BaseModel): + """API响应格式""" + success: bool = Field(..., description="是否成功") + message: str = Field(..., description="响应消息") + data: Optional[Any] = Field(None, description="响应数据") + error: Optional[str] = Field(None, description="错误信息") + +class ConnectionResponse(BaseModel): + """连接响应""" + connection_id: str = Field(..., description="连接ID") + db_type: str = Field(..., description="数据库类型") + host: str = Field(..., description="主机地址") + port: int = Field(..., description="端口") + database: Optional[str] = Field(None, description="数据库名称") + +class DatabaseInfo(BaseModel): + """数据库信息""" + database_name: str = Field(..., description="数据库名称") + tables: List[str] = Field(..., description="表列表") + table_count: int = Field(..., description="表数量") + +class TableInfo(BaseModel): + """表信息""" + table_name: str = Field(..., description="表名") + columns: List[Dict[str, Any]] = Field(..., description="列信息") + primary_keys: Dict[str, Any] = Field(..., description="主键信息") + foreign_keys: List[Dict[str, Any]] = Field(..., description="外键信息") + indexes: List[Dict[str, Any]] = Field(..., description="索引信息") + +class QueryResult(BaseModel): + """查询结果""" + data: List[Dict[str, Any]] = Field(..., description="查询数据") + total: int = Field(..., description="总记录数") + page: int = Field(..., description="当前页码") + page_size: int = Field(..., description="每页大小") \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..5e56de8 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,6 @@ +"""SQLAlchemy ORM 模型包 + +存放项目中使用到的SQLAlchemy模型。当前主要通过运行时创建与查询, +如需持久化ORM模型可在此处新增。 +""" + diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..c9350c2 --- /dev/null +++ b/models/base.py @@ -0,0 +1,16 @@ +"""SQLAlchemy Base 定义与示例模型 + +提供 SQLAlchemy 的 Base 供后续 ORM 模型继承。 +""" + +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + +# 示例:如需添加ORM模型,可参考以下结构 +# from sqlalchemy import Column, Integer, String +# class User(Base): +# __tablename__ = "users" +# id = Column(Integer, primary_key=True, autoincrement=True) +# name = Column(String(100), nullable=False) + diff --git a/quick_oracle_test.py b/quick_oracle_test.py new file mode 100644 index 0000000..f0e4977 --- /dev/null +++ b/quick_oracle_test.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +快速Oracle连接测试脚本 +用于快速诊断Oracle连接问题 +""" + +import oracledb +import logging +from config import DatabaseConfig + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def test_oracle_connection(): + """ + 测试Oracle连接的多种方式 + """ + # 获取Oracle配置 + config = DatabaseConfig.get_config("oracle") + + print("=" * 60) + print("快速Oracle连接测试") + print("=" * 60) + print(f"配置信息:") + print(f" 主机: {config['host']}") + print(f" 端口: {config['port']}") + print(f" 用户名: {config['username']}") + print(f" 服务名: {config['service_name']}") + print("=" * 60) + + # 测试方式列表 + test_methods = [ + { + 'name': '方式1: Easy Connect字符串', + 'params': { + 'user': config['username'], + 'password': config['password'], + 'dsn': f"{config['host']}:{config['port']}/{config['service_name']}" + } + }, + { + 'name': '方式2: 分离参数 (service_name)', + 'params': { + 'user': config['username'], + 'password': config['password'], + 'host': config['host'], + 'port': config['port'], + 'service_name': config['service_name'] + } + }, + { + 'name': '方式3: makedsn (service_name)', + 'params': { + 'user': config['username'], + 'password': config['password'], + 'dsn': oracledb.makedsn(config['host'], config['port'], service_name=config['service_name']) + } + }, + { + 'name': '方式4: makedsn (SID)', + 'params': { + 'user': config['username'], + 'password': config['password'], + 'dsn': oracledb.makedsn(config['host'], config['port'], sid=config['service_name']) + } + } + ] + + success_count = 0 + + for i, method in enumerate(test_methods, 1): + print(f"\n🔍 测试 {method['name']}...") + + try: + # 显示连接参数 + if 'dsn' in method['params']: + print(f" DSN: {method['params']['dsn']}") + else: + print(f" Host: {method['params']['host']}:{method['params']['port']}") + if 'service_name' in method['params']: + print(f" Service Name: {method['params']['service_name']}") + + # 尝试连接 + connection = oracledb.connect(**method['params']) + + # 测试查询 + cursor = connection.cursor() + cursor.execute("SELECT 1 FROM DUAL") + result = cursor.fetchone() + + # 获取数据库信息 + cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1") + version = cursor.fetchone() + + cursor.close() + connection.close() + + print(f" ✅ 连接成功!") + print(f" 📊 查询结果: {result}") + print(f" 🔖 数据库版本: {version[0] if version else 'Unknown'}") + + success_count += 1 + + except Exception as e: + print(f" ❌ 连接失败: {str(e)}") + + # 分析常见错误 + error_str = str(e).lower() + if 'ora-12514' in error_str: + print(f" 💡 提示: ORA-12514错误通常表示服务名不存在或未注册") + elif 'ora-12541' in error_str: + print(f" 💡 提示: ORA-12541错误通常表示监听器未运行") + elif 'ora-01017' in error_str: + print(f" 💡 提示: ORA-01017错误表示用户名或密码无效") + elif 'ora-12170' in error_str: + print(f" 💡 提示: ORA-12170错误表示连接超时") + + print("\n" + "=" * 60) + print(f"测试结果: {success_count}/{len(test_methods)} 种方式成功") + + if success_count == 0: + print("\n🔧 所有连接方式都失败,可能的解决方案:") + print("1. 检查Oracle数据库服务是否正在运行") + print("2. 检查监听器状态: lsnrctl status") + print("3. 验证服务名是否正确注册") + print("4. 检查网络连接和防火墙设置") + print("5. 确认用户名和密码正确") + print("6. 尝试使用Oracle SQL Developer或其他客户端工具测试连接") + elif success_count < len(test_methods): + print("\n⚠️ 部分连接方式成功,建议使用成功的连接方式") + else: + print("\n🎉 所有连接方式都成功!") + + print("=" * 60) + + return success_count > 0 + +if __name__ == "__main__": + test_oracle_connection() \ No newline at end of file diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..29e44cf --- /dev/null +++ b/readme.md @@ -0,0 +1,68 @@ +## 这是一个数据库接口服务 +- 主要功能是将多种类型的数据库集中到一起,提供统一的接口 +- 支持的数据库类型包括mysql、oracle、sqlserver、postgresql等 +## 主要功能 +- 提供的数据库管理功能,通过传入数据库类型,ip和端口以及用户名和密码,来连接数据库 +- 可以获取到数据库的信息,包括数据库的名称,数据库中的表,字段类型,数据库的字段备注等 +- 提供接口,可以对数据库中的数据进行查操作 +- 提供接口,可以对数据库中的表进行查操作 +- 提供接口,可以对数据库中的字段进行查操作 +- 提供接口,可以修改数据库本身以及字段的备注信息 +## 主要接口包括 +- 数据库管理接口 + - 测试能否连通数据库 + - 获取数据库信息 + - 获取某个数据库中所有的数据库表和表备注信息 + - 获取数据表中字段名和类型以及备注信息 + +## 项目启动步骤 + +### 1. 创建并激活conda环境 +```bash +# 创建Python 3.11环境 +conda create -n database-etl python=3.11 + +# 激活环境 +conda activate database-etl +``` + +### 2. 安装项目依赖 +```bash +# 安装所有依赖包 +pip install -r requirements.txt +``` + +### 3. 启动项目 +```bash +# 启动FastAPI服务 +python main.py +``` + +### 4. 访问服务 +- 服务地址:http://localhost:8000 +- API文档:http://localhost:8000/docs +- ReDoc文档:http://localhost:8000/redoc +- 健康检查:http://localhost:8000/health + +## 主要技术栈 +通过fastapi启动api服务,使用sqlalchemy来创建连接引擎,对不同的数据库使用不同的驱动,mysql使用PyMySQL,Oracle使用oracledb,sqlserver使用pymssql,postgresql使用psycopg2 + +## 其他 +- 主要要将一些基础功能进行封装,例如创建数据库引擎,执行sql语句等 +- 提供的接口要符合restful风格 +- 提供的接口要符合http协议的规范 +- 提供的接口要符合json格式 +- 提供的接口要符合http状态码的规范 +- 提供的接口要符合http头的规范 +- 提供的接口要符合http请求体的规范 +- 提供的接口要符合http响应体的规范 +- main.py作为程序的入口 + +## 接口风格调整 +- GET接口参数统一改为使用URL的query传参,例如 `?a=123&b=321` +- 所有PUT和DELETE接口统一改为POST方法,路径保持不变(除GET去除路径参数外) +- 示例: + - `GET /api/v1/databases/info?connection_id=` + - `GET /api/v1/databases/tables/info?connection_id=&table_name=
` + - `POST /api/v1/tables/data/update` + - `POST /api/v1/tables/data/delete` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8cef1e7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +sqlalchemy==2.0.23 +PyMySQL==1.1.0 +oracledb +pymssql==2.2.8 +psycopg2-binary==2.9.9 +pydantic==2.5.0 +python-multipart==0.0.6 +cryptography==41.0.7 +python-dotenv==1.0.0 \ No newline at end of file diff --git a/sample_data.py b/sample_data.py new file mode 100644 index 0000000..dc4dd8e --- /dev/null +++ b/sample_data.py @@ -0,0 +1,336 @@ +from database_manager import DatabaseManager +from config import DatabaseConfig +import logging + +logger = logging.getLogger(__name__) + +class SampleDataInitializer: + """示例数据初始化器""" + + def __init__(self): + self.db_manager = DatabaseManager() + + def init_mysql_sample_data(self, config: dict = None): + """初始化MySQL示例数据""" + try: + # 获取配置 + if config is None: + config = DatabaseConfig.get_config("mysql") + + # 创建连接 + conn_id = self.db_manager.create_connection( + db_type="mysql", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["database"] + ) + + # 创建示例表 + create_users_table = """ + CREATE TABLE IF NOT EXISTS users ( + id INT AUTO_INCREMENT PRIMARY KEY COMMENT '用户ID', + name VARCHAR(100) NOT NULL COMMENT '用户姓名', + email VARCHAR(150) UNIQUE NOT NULL COMMENT '邮箱地址', + age INT COMMENT '年龄', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间' + ) COMMENT='用户信息表'; + """ + + create_products_table = """ + CREATE TABLE IF NOT EXISTS products ( + id INT AUTO_INCREMENT PRIMARY KEY COMMENT '产品ID', + name VARCHAR(200) NOT NULL COMMENT '产品名称', + price DECIMAL(10,2) NOT NULL COMMENT '价格', + category VARCHAR(100) COMMENT '分类', + description TEXT COMMENT '产品描述', + stock_quantity INT DEFAULT 0 COMMENT '库存数量', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间' + ) COMMENT='产品信息表'; + """ + + # 执行建表语句 + self.db_manager.execute_non_query(conn_id, create_users_table) + self.db_manager.execute_non_query(conn_id, create_products_table) + + # 插入示例数据 + insert_users = """ + INSERT IGNORE INTO users (name, email, age) VALUES + ('张三', 'zhangsan@example.com', 25), + ('李四', 'lisi@example.com', 30), + ('王五', 'wangwu@example.com', 28), + ('赵六', 'zhaoliu@example.com', 35), + ('钱七', 'qianqi@example.com', 22); + """ + + insert_products = """ + INSERT IGNORE INTO products (name, price, category, description, stock_quantity) VALUES + ('苹果手机', 5999.00, '电子产品', '最新款智能手机', 50), + ('笔记本电脑', 8999.00, '电子产品', '高性能办公笔记本', 30), + ('无线耳机', 299.00, '电子产品', '蓝牙无线耳机', 100), + ('咖啡杯', 39.90, '生活用品', '陶瓷咖啡杯', 200), + ('书包', 129.00, '生活用品', '学生书包', 80); + """ + + self.db_manager.execute_non_query(conn_id, insert_users) + self.db_manager.execute_non_query(conn_id, insert_products) + + logger.info(f"MySQL示例数据初始化成功: {conn_id}") + return conn_id + + except Exception as e: + logger.error(f"MySQL示例数据初始化失败: {str(e)}") + return None + + def init_oracle_sample_data(self, config: dict = None): + """初始化Oracle示例数据""" + try: + # 获取配置 + if config is None: + config = DatabaseConfig.get_config("oracle") + + logger.info(f"开始初始化Oracle示例数据,配置: host={config['host']}, port={config['port']}, service_name={config['service_name']}") + + # 创建连接,传递额外参数 + conn_id = self.db_manager.create_connection( + db_type="oracle", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["service_name"], + # 添加Oracle特定参数 + mode=config.get("mode"), + threaded=config.get("threaded", True) + ) + + logger.info(f"Oracle连接创建成功,连接ID: {conn_id}") + + # 创建示例表 + create_employees_table = """ + CREATE TABLE employees ( + employee_id NUMBER PRIMARY KEY, + first_name VARCHAR2(50) NOT NULL, + last_name VARCHAR2(50) NOT NULL, + email VARCHAR2(100) UNIQUE NOT NULL, + phone_number VARCHAR2(20), + hire_date DATE DEFAULT SYSDATE, + job_id VARCHAR2(10), + salary NUMBER(8,2), + department_id NUMBER + ) + """ + + create_departments_table = """ + CREATE TABLE departments ( + department_id NUMBER PRIMARY KEY, + department_name VARCHAR2(100) NOT NULL, + manager_id NUMBER, + location_id NUMBER + ) + """ + + # 创建序列 + create_emp_seq = "CREATE SEQUENCE emp_seq START WITH 1 INCREMENT BY 1" + create_dept_seq = "CREATE SEQUENCE dept_seq START WITH 1 INCREMENT BY 1" + + try: + self.db_manager.execute_non_query(conn_id, "DROP TABLE employees") + self.db_manager.execute_non_query(conn_id, "DROP TABLE departments") + self.db_manager.execute_non_query(conn_id, "DROP SEQUENCE emp_seq") + self.db_manager.execute_non_query(conn_id, "DROP SEQUENCE dept_seq") + except: + pass # 忽略删除错误 + + # 执行建表和序列语句 + self.db_manager.execute_non_query(conn_id, create_departments_table) + self.db_manager.execute_non_query(conn_id, create_employees_table) + self.db_manager.execute_non_query(conn_id, create_dept_seq) + self.db_manager.execute_non_query(conn_id, create_emp_seq) + + # 插入示例数据 + insert_departments = """ + INSERT INTO departments (department_id, department_name, manager_id, location_id) VALUES + (dept_seq.NEXTVAL, '人力资源部', NULL, 1700) + """ + + insert_departments2 = """ + INSERT INTO departments (department_id, department_name, manager_id, location_id) VALUES + (dept_seq.NEXTVAL, '技术部', NULL, 1800) + """ + + insert_departments3 = """ + INSERT INTO departments (department_id, department_name, manager_id, location_id) VALUES + (dept_seq.NEXTVAL, '销售部', NULL, 1900) + """ + + self.db_manager.execute_non_query(conn_id, insert_departments) + self.db_manager.execute_non_query(conn_id, insert_departments2) + self.db_manager.execute_non_query(conn_id, insert_departments3) + + insert_employees = """ + INSERT INTO employees (employee_id, first_name, last_name, email, phone_number, job_id, salary, department_id) VALUES + (emp_seq.NEXTVAL, '张', '三', 'zhang.san@company.com', '13800138001', 'IT_PROG', 8000, 2) + """ + + insert_employees2 = """ + INSERT INTO employees (employee_id, first_name, last_name, email, phone_number, job_id, salary, department_id) VALUES + (emp_seq.NEXTVAL, '李', '四', 'li.si@company.com', '13800138002', 'SA_REP', 6000, 3) + """ + + insert_employees3 = """ + INSERT INTO employees (employee_id, first_name, last_name, email, phone_number, job_id, salary, department_id) VALUES + (emp_seq.NEXTVAL, '王', '五', 'wang.wu@company.com', '13800138003', 'HR_REP', 5500, 1) + """ + + self.db_manager.execute_non_query(conn_id, insert_employees) + self.db_manager.execute_non_query(conn_id, insert_employees2) + self.db_manager.execute_non_query(conn_id, insert_employees3) + + # 添加表注释 + self.db_manager.execute_non_query(conn_id, "COMMENT ON TABLE employees IS '员工信息表'") + self.db_manager.execute_non_query(conn_id, "COMMENT ON TABLE departments IS '部门信息表'") + + # 添加列注释 + self.db_manager.execute_non_query(conn_id, "COMMENT ON COLUMN employees.employee_id IS '员工ID'") + self.db_manager.execute_non_query(conn_id, "COMMENT ON COLUMN employees.first_name IS '名'") + self.db_manager.execute_non_query(conn_id, "COMMENT ON COLUMN employees.last_name IS '姓'") + self.db_manager.execute_non_query(conn_id, "COMMENT ON COLUMN employees.email IS '邮箱地址'") + self.db_manager.execute_non_query(conn_id, "COMMENT ON COLUMN employees.salary IS '薪资'") + + logger.info(f"Oracle示例数据初始化成功: {conn_id}") + return conn_id + + except Exception as e: + logger.error(f"Oracle示例数据初始化失败: {str(e)}") + return None + + def init_sqlserver_sample_data(self, config: dict = None): + """初始化SQL Server示例数据""" + try: + # 获取配置 + if config is None: + config = DatabaseConfig.get_config("sqlserver") + + logger.info(f"开始初始化SQL Server示例数据,配置: host={config['host']}, port={config['port']}, database={config['database']}") + + # 创建连接 + conn_id = self.db_manager.create_connection( + db_type="sqlserver", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["database"] + ) + + logger.info(f"SQL Server连接创建成功,连接ID: {conn_id}") + + # 创建示例表 + create_customers_table = """ + IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='customers' AND xtype='U') + CREATE TABLE customers ( + customer_id INT IDENTITY(1,1) PRIMARY KEY, + company_name NVARCHAR(100) NOT NULL, + contact_name NVARCHAR(50), + contact_title NVARCHAR(30), + address NVARCHAR(100), + city NVARCHAR(50), + region NVARCHAR(50), + postal_code NVARCHAR(20), + country NVARCHAR(50), + phone NVARCHAR(30), + email NVARCHAR(100), + created_date DATETIME DEFAULT GETDATE() + ) + """ + + create_orders_table = """ + IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='orders' AND xtype='U') + CREATE TABLE orders ( + order_id INT IDENTITY(1,1) PRIMARY KEY, + customer_id INT, + order_date DATETIME DEFAULT GETDATE(), + required_date DATETIME, + shipped_date DATETIME, + ship_via INT, + freight DECIMAL(10,2), + ship_name NVARCHAR(100), + ship_address NVARCHAR(100), + ship_city NVARCHAR(50), + ship_region NVARCHAR(50), + ship_postal_code NVARCHAR(20), + ship_country NVARCHAR(50), + FOREIGN KEY (customer_id) REFERENCES customers(customer_id) + ) + """ + + # 执行建表语句 + self.db_manager.execute_non_query(conn_id, create_customers_table) + self.db_manager.execute_non_query(conn_id, create_orders_table) + + # 插入示例数据 - 客户表 + customers_data = [ + "INSERT INTO customers (company_name, contact_name, contact_title, address, city, region, postal_code, country, phone, email) VALUES ('北京科技有限公司', '张三', '总经理', '北京市朝阳区建国路1号', '北京', '华北', '100001', '中国', '010-12345678', 'zhangsan@bjtech.com')", + "INSERT INTO customers (company_name, contact_name, contact_title, address, city, region, postal_code, country, phone, email) VALUES ('上海贸易公司', '李四', '销售经理', '上海市浦东新区陆家嘴路100号', '上海', '华东', '200001', '中国', '021-87654321', 'lisi@shtrade.com')", + "INSERT INTO customers (company_name, contact_name, contact_title, address, city, region, postal_code, country, phone, email) VALUES ('广州制造企业', '王五', '采购主管', '广州市天河区珠江路200号', '广州', '华南', '510001', '中国', '020-11223344', 'wangwu@gzmfg.com')", + "INSERT INTO customers (company_name, contact_name, contact_title, address, city, region, postal_code, country, phone, email) VALUES ('深圳创新公司', '赵六', '技术总监', '深圳市南山区科技园300号', '深圳', '华南', '518001', '中国', '0755-99887766', 'zhaoliu@szinno.com')", + "INSERT INTO customers (company_name, contact_name, contact_title, address, city, region, postal_code, country, phone, email) VALUES ('成都服务公司', '钱七', '客户经理', '成都市锦江区春熙路400号', '成都', '西南', '610001', '中国', '028-55443322', 'qianqi@cdservice.com')" + ] + + for sql in customers_data: + self.db_manager.execute_non_query(conn_id, sql) + + # 插入示例数据 - 订单表 + orders_data = [ + "INSERT INTO orders (customer_id, required_date, freight, ship_name, ship_address, ship_city, ship_region, ship_postal_code, ship_country) VALUES (1, DATEADD(day, 7, GETDATE()), 25.50, '北京科技有限公司', '北京市朝阳区建国路1号', '北京', '华北', '100001', '中国')", + "INSERT INTO orders (customer_id, required_date, freight, ship_name, ship_address, ship_city, ship_region, ship_postal_code, ship_country) VALUES (2, DATEADD(day, 10, GETDATE()), 35.75, '上海贸易公司', '上海市浦东新区陆家嘴路100号', '上海', '华东', '200001', '中国')", + "INSERT INTO orders (customer_id, required_date, freight, ship_name, ship_address, ship_city, ship_region, ship_postal_code, ship_country) VALUES (3, DATEADD(day, 5, GETDATE()), 18.25, '广州制造企业', '广州市天河区珠江路200号', '广州', '华南', '510001', '中国')", + "INSERT INTO orders (customer_id, required_date, freight, ship_name, ship_address, ship_city, ship_region, ship_postal_code, ship_country) VALUES (4, DATEADD(day, 14, GETDATE()), 42.00, '深圳创新公司', '深圳市南山区科技园300号', '深圳', '华南', '518001', '中国')", + "INSERT INTO orders (customer_id, required_date, freight, ship_name, ship_address, ship_city, ship_region, ship_postal_code, ship_country) VALUES (5, DATEADD(day, 12, GETDATE()), 28.90, '成都服务公司', '成都市锦江区春熙路400号', '成都', '西南', '610001', '中国')" + ] + + for sql in orders_data: + self.db_manager.execute_non_query(conn_id, sql) + + logger.info(f"SQL Server示例数据初始化成功: {conn_id}") + return conn_id + + except Exception as e: + logger.error(f"SQL Server示例数据初始化失败: {str(e)}") + return None + + def initialize_all_sample_data(self): + """初始化所有示例数据""" + # 检查是否启用示例数据 + if not DatabaseConfig.is_sample_data_enabled(): + logger.info("示例数据初始化已禁用") + return {"mysql": None, "oracle": None, "sqlserver": None} + + logger.info("开始初始化示例数据...") + + # 初始化MySQL示例数据 + mysql_conn = self.init_mysql_sample_data() + if mysql_conn: + logger.info("MySQL示例数据初始化完成") + else: + logger.warning("MySQL示例数据初始化失败,请检查数据库连接配置") + + # 初始化Oracle示例数据 + oracle_conn = self.init_oracle_sample_data() + if oracle_conn: + logger.info("Oracle示例数据初始化完成") + else: + logger.warning("Oracle示例数据初始化失败,请检查数据库连接配置") + + # 初始化SQL Server示例数据 + sqlserver_conn = self.init_sqlserver_sample_data() + if sqlserver_conn: + logger.info("SQL Server示例数据初始化完成") + else: + logger.warning("SQL Server示例数据初始化失败,请检查数据库连接配置") + + logger.info("示例数据初始化流程完成") + return {"mysql": mysql_conn, "oracle": oracle_conn, "sqlserver": sqlserver_conn} \ No newline at end of file diff --git a/schemas/__init__.py b/schemas/__init__.py new file mode 100644 index 0000000..eabbc6e --- /dev/null +++ b/schemas/__init__.py @@ -0,0 +1,142 @@ +"""Pydantic 模型定义 + +此包包含所有请求/响应的 Pydantic 模型与枚举。 +""" + +from pydantic import BaseModel, Field +from typing import Dict, List, Any, Optional +from enum import Enum + + +class DatabaseType(str, Enum): + """支持的数据库类型枚举""" + MYSQL = "mysql" + ORACLE = "oracle" + SQLSERVER = "sqlserver" + POSTGRESQL = "postgresql" + + +class DatabaseConnection(BaseModel): + """数据库连接配置模型""" + db_type: DatabaseType = Field(..., description="数据库类型") + host: str = Field(..., description="数据库主机地址") + port: int = Field(..., description="数据库端口") + username: str = Field(..., description="用户名") + password: str = Field(..., description="密码") + database: Optional[str] = Field(None, description="数据库名称") + # Oracle特定参数 + mode: Optional[str] = Field(None, description="Oracle连接模式") + threaded: Optional[bool] = Field(None, description="Oracle是否启用线程模式") + # 其他连接参数 + extra_params: Optional[Dict[str, Any]] = Field(None, description="额外的连接参数") + + +class QueryRequest(BaseModel): + """查询请求模型""" + connection_id: str = Field(..., description="连接ID") + sql: str = Field(..., description="SQL语句") + params: Optional[Dict[str, Any]] = Field(None, description="SQL参数") + + +class ExecuteRequest(BaseModel): + """非查询执行请求模型""" + connection_id: str = Field(..., description="连接ID") + sql: str = Field(..., description="SQL语句") + params: Optional[Dict[str, Any]] = Field(None, description="SQL参数") + + +class TableDataRequest(BaseModel): + """表数据查询请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + page: int = Field(1, description="页码") + page_size: int = Field(10, description="每页大小") + where_clause: Optional[str] = Field(None, description="WHERE条件") + order_by: Optional[str] = Field(None, description="排序字段") + + +class InsertDataRequest(BaseModel): + """插入数据请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + data: Dict[str, Any] = Field(..., description="要插入的数据") + + +class UpdateDataRequest(BaseModel): + """更新数据请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + data: Dict[str, Any] = Field(..., description="要更新的数据") + where_clause: str = Field(..., description="WHERE条件") + + +class DeleteDataRequest(BaseModel): + """删除数据请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + where_clause: str = Field(..., description="WHERE条件") + + +class CreateTableRequest(BaseModel): + """创建表请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + columns: List[Dict[str, Any]] = Field(..., description="列定义") + + +class AlterTableRequest(BaseModel): + """修改表结构请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + operation: str = Field(..., description="操作类型: ADD, DROP, MODIFY") + column_definition: Optional[Dict[str, Any]] = Field(None, description="列定义") + + +class CommentRequest(BaseModel): + """修改备注请求模型""" + connection_id: str = Field(..., description="连接ID") + table_name: str = Field(..., description="表名") + column_name: Optional[str] = Field(None, description="列名(为空则修改表备注)") + comment: str = Field(..., description="备注内容") + + +class ApiResponse(BaseModel): + """统一API响应模型""" + success: bool = Field(..., description="是否成功") + message: str = Field(..., description="响应消息") + data: Optional[Any] = Field(None, description="响应数据") + error: Optional[str] = Field(None, description="错误信息") + + +class ConnectionResponse(BaseModel): + """连接响应模型""" + connection_id: str = Field(..., description="连接ID") + db_type: str = Field(..., description="数据库类型") + host: str = Field(..., description="主机地址") + port: int = Field(..., description="端口") + database: Optional[str] = Field(None, description="数据库名称") + + +class DatabaseInfo(BaseModel): + """数据库信息响应模型""" + database_name: str = Field(..., description="数据库名称") + tables: List[str] = Field(..., description="表列表") + table_count: int = Field(..., description="表数量") + + +class TableInfo(BaseModel): + """表信息响应模型""" + table_name: str = Field(..., description="表名") + columns: List[Dict[str, Any]] = Field(..., description="列信息") + primary_keys: Dict[str, Any] = Field(..., description="主键信息") + foreign_keys: List[Dict[str, Any]] = Field(..., description="外键信息") + indexes: List[Dict[str, Any]] = Field(..., description="索引信息") + + +class QueryResult(BaseModel): + """查询结果响应模型""" + data: List[Dict[str, Any]] = Field(..., description="查询数据") + total: int = Field(..., description="总记录数") + page: int = Field(..., description="当前页码") + page_size: int = Field(..., description="每页大小") + diff --git a/schemas/__pycache__/__init__.cpython-311.pyc b/schemas/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..2af9eb0 Binary files /dev/null and b/schemas/__pycache__/__init__.cpython-311.pyc differ diff --git a/start.bat b/start.bat new file mode 100644 index 0000000..5fa68d5 --- /dev/null +++ b/start.bat @@ -0,0 +1,65 @@ +@echo off +chcp 65001 >nul +echo ======================================== +echo 数据库接口服务启动脚本 +echo ======================================== +echo. + +echo [1/4] 检查conda环境... +call conda info --envs | findstr "database-etl" >nul +if %errorlevel% neq 0 ( + echo 环境不存在,正在创建conda环境... + call conda create -n database-etl python=3.11 -y + if %errorlevel% neq 0 ( + echo 创建环境失败! + pause + exit /b 1 + ) + echo 环境创建成功! +) else ( + echo conda环境已存在 +) +echo. + +echo [2/4] 激活conda环境... +call conda activate database-etl +if %errorlevel% neq 0 ( + echo 激活环境失败! + pause + exit /b 1 +) +echo 环境激活成功! +echo. + +echo [3/4] 安装项目依赖... +if exist requirements.txt ( + pip install -r requirements.txt + if %errorlevel% neq 0 ( + echo 依赖安装失败! + pause + exit /b 1 + ) + echo 依赖安装成功! +) else ( + echo 警告:未找到requirements.txt文件 + pause +) +echo. + +echo [4/4] 启动项目服务... +echo 正在启动FastAPI服务... +echo 服务地址: http://localhost:8000 +echo API文档: http://localhost:8000/docs +echo 按Ctrl+C停止服务 +echo. +if not exist main.py ( + echo 错误:未找到main.py文件! + pause + exit /b 1 +) + +python main.py + +echo. +echo 服务已停止 +pause \ No newline at end of file diff --git a/test_oracle_connection.py b/test_oracle_connection.py new file mode 100644 index 0000000..d8a6cd8 --- /dev/null +++ b/test_oracle_connection.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Oracle数据库连接测试脚本 +用于测试Oracle数据库连接是否正常 +""" + +import requests +import json +import oracledb +from typing import Dict, Any +import logging + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Oracle连接配置 +ORACLE_CONFIG = { + "db_type": "oracle", + "host": "192.168.13.200", + "port": 1521, + "username": "bizuser", + "password": "MySecurePass123", + "database": "ORCLPDB1", # 服务名称 + "mode": None, # 可以设置为 "SYSDBA", "SYSOPER" 等 + "threaded": True, + "extra_params": { + # 可以添加其他Oracle特定参数 + } +} + +# API基础URL +BASE_URL = "http://localhost:8000" + +def test_oracle_connection(config: Dict[str, Any]) -> bool: + """ + 测试Oracle数据库连接 + + Args: + config: Oracle连接配置 + + Returns: + bool: 连接是否成功 + """ + try: + # 创建连接 + print("正在测试Oracle连接...") + print(f"连接信息: {config['host']}:{config['port']}/{config['database']}") + print(f"用户名: {config['username']}") + + response = requests.post( + f"{BASE_URL}/connections", + json=config, + headers={"Content-Type": "application/json"} + ) + + if response.status_code == 200: + result = response.json() + if result.get("success"): + connection_id = result["data"]["connection_id"] + print(f"✅ Oracle连接成功! 连接ID: {connection_id}") + + # 测试查询 + test_query(connection_id) + + # 关闭连接 + close_connection(connection_id) + return True + else: + print(f"❌ 连接失败: {result.get('message', '未知错误')}") + if result.get('error'): + print(f"错误详情: {result['error']}") + return False + else: + print(f"❌ HTTP请求失败: {response.status_code}") + print(f"响应内容: {response.text}") + return False + + except Exception as e: + print(f"❌ 连接测试异常: {str(e)}") + return False + +def test_query(connection_id: str): + """ + 测试查询操作 + + Args: + connection_id: 连接ID + """ + try: + print("\n正在测试查询操作...") + + query_request = { + "connection_id": connection_id, + "sql": "SELECT 1 FROM DUAL" + } + + response = requests.post( + f"{BASE_URL}/query", + json=query_request, + headers={"Content-Type": "application/json"} + ) + + if response.status_code == 200: + result = response.json() + if result.get("success"): + print("✅ 查询测试成功!") + print(f"查询结果: {result['data']}") + else: + print(f"❌ 查询失败: {result.get('message', '未知错误')}") + else: + print(f"❌ 查询请求失败: {response.status_code}") + + except Exception as e: + print(f"❌ 查询测试异常: {str(e)}") + +def close_connection(connection_id: str): + """ + 关闭数据库连接 + + Args: + connection_id: 连接ID + """ + try: + print("\n正在关闭连接...") + + response = requests.delete( + f"{BASE_URL}/connections/{connection_id}" + ) + + if response.status_code == 200: + result = response.json() + if result.get("success"): + print("✅ 连接已关闭") + else: + print(f"❌ 关闭连接失败: {result.get('message', '未知错误')}") + else: + print(f"❌ 关闭连接请求失败: {response.status_code}") + + except Exception as e: + print(f"❌ 关闭连接异常: {str(e)}") + +def test_direct_oracle_connection(config: Dict[str, Any]) -> bool: + """ + 直接测试Oracle连接(不通过API) + + Args: + config: Oracle连接配置 + + Returns: + bool: 连接是否成功 + """ + try: + print("\n🔍 直接测试Oracle连接...") + + # 方式1: 使用Easy Connect字符串 + dsn1 = f"{config['host']}:{config['port']}/{config['database']}" + print(f"尝试连接方式1 - Easy Connect: {dsn1}") + + connection = oracledb.connect( + user=config['username'], + password=config['password'], + dsn=dsn1 + ) + + # 测试查询 + cursor = connection.cursor() + cursor.execute("SELECT 1 FROM DUAL") + result = cursor.fetchone() + print(f"✅ 直接连接成功! 查询结果: {result}") + + # 获取数据库版本信息 + cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1") + version = cursor.fetchone() + print(f"📊 数据库版本: {version[0] if version else 'Unknown'}") + + cursor.close() + connection.close() + + return True + + except Exception as e: + print(f"❌ 直接连接失败: {str(e)}") + + # 尝试其他连接方式 + try: + print("\n🔄 尝试其他连接方式...") + + # 方式2: 使用分离的参数 + print(f"尝试连接方式2 - 分离参数: host={config['host']}, port={config['port']}, service_name={config['database']}") + + connection = oracledb.connect( + user=config['username'], + password=config['password'], + host=config['host'], + port=config['port'], + service_name=config['database'] + ) + + cursor = connection.cursor() + cursor.execute("SELECT 1 FROM DUAL") + result = cursor.fetchone() + print(f"✅ 方式2连接成功! 查询结果: {result}") + + cursor.close() + connection.close() + + return True + + except Exception as e2: + print(f"❌ 方式2也失败: {str(e2)}") + + # 尝试使用SID而不是服务名 + try: + print("\n🔄 尝试使用SID连接...") + + # 方式3: 使用SID + dsn3 = oracledb.makedsn(config['host'], config['port'], sid=config['database']) + print(f"尝试连接方式3 - SID: {dsn3}") + + connection = oracledb.connect( + user=config['username'], + password=config['password'], + dsn=dsn3 + ) + + cursor = connection.cursor() + cursor.execute("SELECT 1 FROM DUAL") + result = cursor.fetchone() + print(f"✅ SID连接成功! 查询结果: {result}") + + cursor.close() + connection.close() + + return True + + except Exception as e3: + print(f"❌ SID连接也失败: {str(e3)}") + return False + +def main(): + """ + 主函数 + """ + print("=" * 60) + print("Oracle数据库连接测试 - 增强版") + print("=" * 60) + + # 首先进行直接连接测试 + direct_success = test_direct_oracle_connection(ORACLE_CONFIG) + + if not direct_success: + print("\n💥 直接Oracle连接失败!") + print("\n🔧 可能的解决方案:") + print("1. 检查Oracle服务是否正在运行") + print("2. 检查网络连接和防火墙设置") + print("3. 验证用户名、密码和服务名称") + print("4. 确认Oracle客户端库已正确安装: pip install oracledb") + print("5. 检查Oracle监听器配置: lsnrctl status") + print("6. 尝试使用SID而不是服务名") + print("7. 检查服务名是否正确注册到监听器") + print("=" * 60) + return + + # 检查API服务是否运行 + print("\n🌐 检查API服务状态...") + try: + response = requests.get(f"{BASE_URL}/docs") + if response.status_code != 200: + print("❌ API服务未运行,请先启动服务: python main.py") + return + except requests.exceptions.ConnectionError: + print("❌ 无法连接到API服务,请先启动服务: python main.py") + return + + # 测试通过API的Oracle连接 + print("\n🔗 测试通过API的Oracle连接...") + api_success = test_oracle_connection(ORACLE_CONFIG) + + print("\n" + "=" * 60) + if direct_success and api_success: + print("🎉 所有Oracle连接测试通过!") + elif direct_success: + print("⚠️ 直接连接成功,但API连接失败") + print("请检查API服务中的Oracle连接配置") + else: + print("💥 Oracle连接测试失败!") + print("=" * 60) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_url_encoding.py b/test_url_encoding.py new file mode 100644 index 0000000..53c3676 --- /dev/null +++ b/test_url_encoding.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +URL编码测试脚本 + +测试密码中包含特殊字符(如@符号)的URL编码处理 +""" + +import sys +import os +from urllib.parse import quote_plus + +# 添加项目根目录到Python路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from database_manager import DatabaseManager +from config import DatabaseConfig + +def test_url_encoding(): + """ + 测试URL编码功能 + """ + print("🔧 URL编码测试") + print("=" * 50) + + # 测试密码编码 + test_passwords = [ + "sqlserver@7740", + "password@123", + "user#pass", + "test&password", + "simple123" + ] + + print("📋 密码编码测试:") + for password in test_passwords: + encoded = quote_plus(password) + print(f" 原始密码: {password}") + print(f" 编码后: {encoded}") + print() + + # 测试SQL Server连接URL构建 + print("🔗 SQL Server连接URL构建测试:") + + db_manager = DatabaseManager() + + # 获取配置 + config = DatabaseConfig.get_config("sqlserver") + + try: + # 构建连接URL + connection_url = db_manager._build_connection_url( + db_type="sqlserver", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["database"] + ) + + print(f"✅ 连接URL构建成功:") + print(f" {connection_url}") + + # 验证URL中不包含原始的@符号(除了用户名密码分隔符) + if "sqlserver@7740" in connection_url: + print("❌ 错误: URL中仍包含未编码的密码") + return False + elif "sqlserver%407740" in connection_url: + print("✅ 正确: 密码已正确编码") + return True + else: + print("⚠️ 警告: 无法确定编码状态") + return True + + except Exception as e: + print(f"❌ 连接URL构建失败: {str(e)}") + return False + +def test_direct_connection(): + """ + 测试直接数据库连接 + """ + print("\n🔌 直接连接测试") + print("=" * 50) + + try: + # 获取配置 + config = DatabaseConfig.get_config("sqlserver") + + print(f"📋 连接配置:") + print(f" 主机: {config['host']}") + print(f" 端口: {config['port']}") + print(f" 数据库: {config['database']}") + print(f" 用户名: {config['username']}") + print(f" 密码: {'*' * len(config['password'])}") + + # 创建数据库管理器 + db_manager = DatabaseManager() + + # 尝试创建连接 + print("\n正在尝试连接...") + connection_id = db_manager.create_connection( + db_type="sqlserver", + host=config["host"], + port=config["port"], + username=config["username"], + password=config["password"], + database=config["database"] + ) + + print(f"✅ SQL Server连接成功! 连接ID: {connection_id}") + + # 测试查询 + try: + result = db_manager.execute_query(connection_id, "SELECT 1 as test_value") + print(f"✅ 查询测试成功: {result}") + + # 获取数据库版本 + version_result = db_manager.execute_query(connection_id, "SELECT @@VERSION as version") + if version_result: + version_info = version_result[0]['version'] + # 只显示版本信息的前100个字符 + print(f"📋 数据库版本: {version_info[:100]}...") + + except Exception as e: + print(f"⚠️ 查询测试失败: {str(e)}") + + # 关闭连接 + db_manager.close_connection(connection_id) + print("✅ 连接已关闭") + + return True + + except Exception as e: + print(f"❌ 连接测试失败: {str(e)}") + + # 分析错误类型 + error_str = str(e) + if "7740@192.168.11.200" in error_str: + print("\n🔍 错误分析: 密码中的@符号仍未正确处理") + print(" 建议检查URL编码逻辑") + elif "Unable to connect" in error_str: + print("\n🔍 错误分析: 无法连接到SQL Server") + print(" 可能原因:") + print(" 1. SQL Server服务未启动") + print(" 2. 网络连接问题") + print(" 3. 防火墙阻止连接") + print(" 4. 用户名或密码错误") + + return False + +def main(): + """ + 主函数 + """ + print("🧪 数据库连接URL编码修复验证") + print("=" * 60) + + # 1. URL编码测试 + url_test_success = test_url_encoding() + + # 2. 直接连接测试 + connection_test_success = test_direct_connection() + + # 总结 + print("\n" + "=" * 60) + print("📊 测试结果总结:") + print(f" URL编码: {'✅ 通过' if url_test_success else '❌ 失败'}") + print(f" 连接测试: {'✅ 通过' if connection_test_success else '❌ 失败'}") + + if url_test_success and connection_test_success: + print("\n🎉 所有测试通过! URL编码修复成功") + print(" 现在可以正常使用包含特殊字符的密码了") + elif url_test_success and not connection_test_success: + print("\n⚠️ URL编码修复成功,但连接仍有问题") + print(" 请检查SQL Server配置和网络连接") + else: + print("\n❌ 测试失败,需要进一步检查代码") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..17cf992 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +"""测试包初始化 + +用于放置单元测试与集成测试。 +""" +