first commit

This commit is contained in:
2026-03-04 12:17:52 +08:00
commit ecb3e1d9b2
42 changed files with 4081 additions and 0 deletions

508
database_manager.py Normal file
View File

@@ -0,0 +1,508 @@
from sqlalchemy import create_engine, text, MetaData, Table, Column, inspect
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError
from typing import Dict, List, Any, Optional
import logging
import oracledb
from urllib.parse import quote_plus
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatabaseManager:
"""数据库管理类,支持多种数据库类型的连接和操作"""
def __init__(self):
self.engines = {}
self.sessions = {}
def create_connection(self, db_type: str, host: str, port: int,
username: str, password: str, database: str = None, **kwargs) -> str:
"""创建数据库连接
Args:
db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql)
host: 数据库主机地址
port: 数据库端口
username: 用户名
password: 密码
database: 数据库名称
Returns:
connection_id: 连接ID
"""
try:
connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs)
connection_id = f"{db_type}_{host}_{port}_{database or 'default'}"
# 配置引擎参数
engine_kwargs = {"echo": False, "pool_pre_ping": True}
if db_type.lower() == 'oracle':
# Oracle连接池配置
engine_kwargs.update({
"pool_size": 5,
"max_overflow": 10,
"pool_timeout": 30,
"pool_recycle": 3600,
"pool_reset_on_return": "commit"
})
engine = create_engine(connection_url, **engine_kwargs)
# 测试连接
with engine.connect() as conn:
if db_type.lower() == 'oracle':
conn.execute(text("SELECT 1 FROM DUAL"))
else:
conn.execute(text("SELECT 1"))
self.engines[connection_id] = engine
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
self.sessions[connection_id] = SessionLocal
logger.info(f"成功创建数据库连接: {connection_id}")
return connection_id
except Exception as e:
logger.error(f"创建数据库连接失败: {str(e)}")
raise Exception(f"数据库连接失败: {str(e)}")
def test_connection(self, db_type: str, host: str, port: int,
username: str, password: str, database: str = None, **kwargs) -> Dict[str, Any]:
"""测试数据库是否可连通
Args:
db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql)
host: 数据库主机地址
port: 数据库端口
username: 用户名
password: 密码
database: 数据库名称
Returns:
Dict[str, Any]: 测试结果信息,包含连接是否成功与版本信息
"""
try:
# 构建连接URL内部会对Oracle进行预测试
connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs)
engine = create_engine(connection_url, echo=False, pool_pre_ping=True)
server_version = None
with engine.connect() as conn:
# 基本连通性测试
if db_type.lower() == 'oracle':
conn.execute(text("SELECT 1 FROM DUAL"))
version_sql = "SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1"
server_version = conn.execute(text(version_sql)).scalar()
elif db_type.lower() == 'mysql':
conn.execute(text("SELECT 1"))
server_version = conn.execute(text("SELECT VERSION()")).scalar()
elif db_type.lower() == 'postgresql':
conn.execute(text("SELECT 1"))
server_version = conn.execute(text("SELECT version()")).scalar()
elif db_type.lower() == 'sqlserver':
conn.execute(text("SELECT 1"))
server_version = conn.execute(text("SELECT @@VERSION")).scalar()
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
# 释放临时引擎
engine.dispose()
return {
"ok": True,
"db_type": db_type,
"connection_url": str(connection_url),
"server_version": server_version
}
except Exception as e:
logger.error(f"连接测试失败: {str(e)}")
return {
"ok": False,
"db_type": db_type,
"error": str(e)
}
def _build_connection_url(self, db_type: str, host: str, port: int,
username: str, password: str, database: str = None, **kwargs) -> str:
"""构建数据库连接URL"""
# 对用户名和密码进行URL编码防止特殊字符导致解析错误
encoded_username = quote_plus(username)
encoded_password = quote_plus(password)
if db_type.lower() == 'mysql':
db_part = f"/{database}" if database else ""
return f"mysql+pymysql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}?charset=utf8mb4"
elif db_type.lower() == 'oracle':
# Oracle连接格式 - 根据oracledb文档优化
service_name = database or 'XE'
# 先测试Oracle连接以确保参数正确
try:
self._test_oracle_connection(
host=host,
port=port,
username=username,
password=password,
service_name=service_name,
**kwargs
)
except Exception as e:
logger.error(f"Oracle连接预测试失败: {str(e)}")
raise e
# 使用简化的SQLAlchemy URL格式
# 根据oracledb文档SQLAlchemy会自动处理连接参数
base_url = f"oracle+oracledb://{encoded_username}:{encoded_password}@{host}:{port}/?service_name={service_name}"
return base_url
elif db_type.lower() == 'sqlserver':
db_part = f"/{database}" if database else ""
return f"mssql+pymssql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"
elif db_type.lower() == 'postgresql':
db_part = f"/{database}" if database else "/postgres"
return f"postgresql+psycopg2://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
def _test_oracle_connection(self, host: str, port: int, username: str, password: str,
service_name: str = None, **kwargs):
"""测试Oracle直接连接"""
service_name = service_name or 'XE'
# 尝试多种连接方式
connection_methods = [
# 方式1: Easy Connect字符串
{
'name': 'Easy Connect',
'params': {
'user': username,
'password': password,
'dsn': f"{host}:{port}/{service_name}"
}
},
# 方式2: 分离参数
{
'name': 'Separate Parameters',
'params': {
'user': username,
'password': password,
'host': host,
'port': port,
'service_name': service_name
}
},
# 方式3: 使用SID
{
'name': 'SID Connection',
'params': {
'user': username,
'password': password,
'dsn': oracledb.makedsn(host, port, sid=service_name)
}
}
]
last_error = None
for method in connection_methods:
try:
logger.info(f"尝试Oracle连接方式: {method['name']}")
# 尝试连接
connection = oracledb.connect(**method['params'])
# 测试查询
cursor = connection.cursor()
cursor.execute("SELECT 1 FROM DUAL")
result = cursor.fetchone()
# 获取数据库版本信息
cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1")
version = cursor.fetchone()
cursor.close()
connection.close()
logger.info(f"Oracle连接成功 ({method['name']}): 查询结果={result}, 版本={version[0] if version else 'Unknown'}")
return # 连接成功,退出函数
except Exception as e:
last_error = e
logger.warning(f"Oracle连接方式 '{method['name']}' 失败: {str(e)}")
continue
# 所有连接方式都失败
error_msg = f"所有Oracle连接方式都失败。最后一个错误: {str(last_error)}"
logger.error(error_msg)
raise Exception(error_msg)
def get_engine(self, connection_id: str):
"""获取数据库引擎"""
if connection_id not in self.engines:
raise ValueError(f"连接ID不存在: {connection_id}")
return self.engines[connection_id]
def get_session(self, connection_id: str):
"""获取数据库会话"""
if connection_id not in self.sessions:
raise ValueError(f"连接ID不存在: {connection_id}")
return self.sessions[connection_id]()
def execute_query(self, connection_id: str, sql: str, params: Dict = None) -> List[Dict]:
"""执行查询SQL"""
try:
engine = self.get_engine(connection_id)
with engine.connect() as conn:
result = conn.execute(text(sql), params or {})
columns = result.keys()
rows = result.fetchall()
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
logger.error(f"执行查询失败: {str(e)}")
raise Exception(f"查询执行失败: {str(e)}")
def execute_non_query(self, connection_id: str, sql: str, params: Dict = None) -> int:
"""执行非查询SQLINSERT, UPDATE, DELETE"""
try:
engine = self.get_engine(connection_id)
with engine.connect() as conn:
with conn.begin():
result = conn.execute(text(sql), params or {})
return result.rowcount
except Exception as e:
logger.error(f"执行非查询失败: {str(e)}")
raise Exception(f"非查询执行失败: {str(e)}")
def get_database_info(self, connection_id: str) -> Dict:
"""获取数据库信息"""
try:
engine = self.get_engine(connection_id)
inspector = inspect(engine)
# 获取数据库名称
with engine.connect() as conn:
db_name_result = conn.execute(text("SELECT DATABASE()" if 'mysql' in str(engine.url)
else "SELECT CURRENT_DATABASE()" if 'postgresql' in str(engine.url)
else "SELECT DB_NAME()" if 'mssql' in str(engine.url)
else "SELECT SYS_CONTEXT('USERENV', 'DB_NAME') FROM DUAL"))
db_name = db_name_result.scalar()
# 获取表列表
tables = inspector.get_table_names()
return {
"database_name": db_name,
"tables": tables,
"table_count": len(tables)
}
except Exception as e:
logger.error(f"获取数据库信息失败: {str(e)}")
raise Exception(f"获取数据库信息失败: {str(e)}")
def get_tables_with_comments(self, connection_id: str) -> List[Dict[str, Any]]:
"""获取数据库中所有表及其备注信息"""
try:
engine = self.get_engine(connection_id)
url_str = str(engine.url)
with engine.connect() as conn:
if 'mysql' in url_str:
sql = """
SELECT
TABLE_NAME AS table_name,
TABLE_COMMENT AS table_comment
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = DATABASE()
ORDER BY TABLE_NAME
"""
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
elif 'postgresql' in url_str:
# 获取当前schema
schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public'
sql = """
SELECT
c.relname AS table_name,
obj_description(c.oid) AS table_comment
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND c.relkind IN ('r','p')
ORDER BY c.relname
"""
return [dict(row) for row in conn.execute(text(sql), {"schema": schema}).mappings().all()]
elif 'mssql' in url_str:
sql = """
SELECT
t.name AS table_name,
CAST(ep.value AS NVARCHAR(MAX)) AS table_comment
FROM sys.tables t
LEFT JOIN sys.extended_properties ep
ON ep.major_id = t.object_id
AND ep.minor_id = 0
AND ep.name = 'MS_Description'
ORDER BY t.name
"""
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
elif 'oracle' in url_str:
sql = """
SELECT
table_name AS table_name,
comments AS table_comment
FROM user_tab_comments
ORDER BY table_name
"""
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
else:
raise ValueError("不支持的数据库类型")
except Exception as e:
logger.error(f"获取表备注信息失败: {str(e)}")
raise Exception(f"获取表备注信息失败: {str(e)}")
def get_table_info(self, connection_id: str, table_name: str) -> Dict:
"""获取表信息
序列化说明:
- SQLAlchemy 的列类型对象(如 `VARCHAR`、`NUMBER` 等不可直接JSON序列化。
- 本方法将列的 `type` 字段统一转换为字符串表示(例如 `VARCHAR(255)`)。
- 其他由 Inspector 返回的结构(主键、外键、索引)保持原始可序列化格式。
"""
try:
engine = self.get_engine(connection_id)
inspector = inspect(engine)
# 获取列信息
raw_columns = inspector.get_columns(table_name)
columns = []
for col in raw_columns:
# 创建可序列化的列字典
serializable = dict(col)
col_type = serializable.get('type')
# 将SQLAlchemy类型对象转换为字符串
if col_type is not None:
try:
serializable['type'] = str(col_type)
except Exception:
# 兜底:使用类型名
serializable['type'] = type(col_type).__name__
columns.append(serializable)
# 获取主键
primary_keys = inspector.get_pk_constraint(table_name)
# 获取外键
foreign_keys = inspector.get_foreign_keys(table_name)
# 获取索引
indexes = inspector.get_indexes(table_name)
return {
"table_name": table_name,
"columns": columns,
"primary_keys": primary_keys,
"foreign_keys": foreign_keys,
"indexes": indexes
}
except Exception as e:
logger.error(f"获取表信息失败: {str(e)}")
raise Exception(f"获取表信息失败: {str(e)}")
def get_table_columns(self, connection_id: str, table_name: str) -> List[Dict[str, Any]]:
"""获取指定表的字段名、类型、备注等信息(兼容多数据库)"""
try:
engine = self.get_engine(connection_id)
url_str = str(engine.url)
with engine.connect() as conn:
if 'mysql' in url_str:
sql = """
SELECT
COLUMN_NAME AS column_name,
DATA_TYPE AS data_type,
IS_NULLABLE AS is_nullable,
COLUMN_DEFAULT AS column_default,
COLUMN_COMMENT AS column_comment,
CHARACTER_MAXIMUM_LENGTH AS max_length,
NUMERIC_PRECISION AS numeric_precision,
NUMERIC_SCALE AS numeric_scale
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = :table_name
ORDER BY ORDINAL_POSITION
"""
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
elif 'postgresql' in url_str:
sql = """
SELECT
a.attname AS column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
NOT a.attnotnull AS is_nullable,
pg_get_expr(d.adbin, d.adrelid) AS column_default,
col_description(a.attrelid, a.attnum) AS column_comment
FROM pg_attribute a
JOIN pg_class c ON a.attrelid = c.oid
LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
WHERE c.relname = :table_name
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum
"""
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
elif 'mssql' in url_str:
sql = """
SELECT
c.name AS column_name,
t.name AS data_type,
CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable,
OBJECT_DEFINITION(c.default_object_id) AS column_default,
CAST(ep.value AS NVARCHAR(MAX)) AS column_comment,
COLUMNPROPERTY(c.object_id, c.name, 'charmaxlen') AS max_length
FROM sys.columns c
LEFT JOIN sys.types t ON c.user_type_id = t.user_type_id
LEFT JOIN sys.extended_properties ep
ON ep.major_id = c.object_id
AND ep.minor_id = c.column_id
AND ep.name = 'MS_Description'
WHERE c.object_id = OBJECT_ID(:table_name)
ORDER BY c.column_id
"""
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
elif 'oracle' in url_str:
sql = """
SELECT
cols.column_name AS column_name,
cols.data_type AS data_type,
cols.nullable AS is_nullable,
cols.data_default AS column_default,
comm.comments AS column_comment,
cols.data_length AS max_length
FROM user_tab_columns cols
LEFT JOIN user_col_comments comm
ON comm.table_name = cols.table_name
AND comm.column_name = cols.column_name
WHERE cols.table_name = UPPER(:table_name)
ORDER BY cols.column_id
"""
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
else:
raise ValueError("不支持的数据库类型")
except Exception as e:
logger.error(f"获取字段信息失败: {str(e)}")
raise Exception(f"获取字段信息失败: {str(e)}")
def close_connection(self, connection_id: str):
"""关闭数据库连接"""
try:
if connection_id in self.engines:
self.engines[connection_id].dispose()
del self.engines[connection_id]
del self.sessions[connection_id]
logger.info(f"已关闭数据库连接: {connection_id}")
except Exception as e:
logger.error(f"关闭连接失败: {str(e)}")
def list_connections(self) -> List[str]:
"""列出所有活动连接"""
return list(self.engines.keys())
# 全局数据库管理器实例
db_manager = DatabaseManager()