first commit
This commit is contained in:
508
database_manager.py
Normal file
508
database_manager.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from sqlalchemy import create_engine, text, MetaData, Table, Column, inspect
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
import oracledb
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理类,支持多种数据库类型的连接和操作"""
|
||||
|
||||
def __init__(self):
|
||||
self.engines = {}
|
||||
self.sessions = {}
|
||||
|
||||
def create_connection(self, db_type: str, host: str, port: int,
|
||||
username: str, password: str, database: str = None, **kwargs) -> str:
|
||||
"""创建数据库连接
|
||||
|
||||
Args:
|
||||
db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql)
|
||||
host: 数据库主机地址
|
||||
port: 数据库端口
|
||||
username: 用户名
|
||||
password: 密码
|
||||
database: 数据库名称
|
||||
|
||||
Returns:
|
||||
connection_id: 连接ID
|
||||
"""
|
||||
try:
|
||||
connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs)
|
||||
connection_id = f"{db_type}_{host}_{port}_{database or 'default'}"
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs = {"echo": False, "pool_pre_ping": True}
|
||||
if db_type.lower() == 'oracle':
|
||||
# Oracle连接池配置
|
||||
engine_kwargs.update({
|
||||
"pool_size": 5,
|
||||
"max_overflow": 10,
|
||||
"pool_timeout": 30,
|
||||
"pool_recycle": 3600,
|
||||
"pool_reset_on_return": "commit"
|
||||
})
|
||||
|
||||
engine = create_engine(connection_url, **engine_kwargs)
|
||||
|
||||
# 测试连接
|
||||
with engine.connect() as conn:
|
||||
if db_type.lower() == 'oracle':
|
||||
conn.execute(text("SELECT 1 FROM DUAL"))
|
||||
else:
|
||||
conn.execute(text("SELECT 1"))
|
||||
|
||||
self.engines[connection_id] = engine
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
self.sessions[connection_id] = SessionLocal
|
||||
|
||||
logger.info(f"成功创建数据库连接: {connection_id}")
|
||||
return connection_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库连接失败: {str(e)}")
|
||||
raise Exception(f"数据库连接失败: {str(e)}")
|
||||
|
||||
def test_connection(self, db_type: str, host: str, port: int,
|
||||
username: str, password: str, database: str = None, **kwargs) -> Dict[str, Any]:
|
||||
"""测试数据库是否可连通
|
||||
|
||||
Args:
|
||||
db_type: 数据库类型 (mysql, oracle, sqlserver, postgresql)
|
||||
host: 数据库主机地址
|
||||
port: 数据库端口
|
||||
username: 用户名
|
||||
password: 密码
|
||||
database: 数据库名称
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 测试结果信息,包含连接是否成功与版本信息
|
||||
"""
|
||||
try:
|
||||
# 构建连接URL(内部会对Oracle进行预测试)
|
||||
connection_url = self._build_connection_url(db_type, host, port, username, password, database, **kwargs)
|
||||
engine = create_engine(connection_url, echo=False, pool_pre_ping=True)
|
||||
|
||||
server_version = None
|
||||
with engine.connect() as conn:
|
||||
# 基本连通性测试
|
||||
if db_type.lower() == 'oracle':
|
||||
conn.execute(text("SELECT 1 FROM DUAL"))
|
||||
version_sql = "SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1"
|
||||
server_version = conn.execute(text(version_sql)).scalar()
|
||||
elif db_type.lower() == 'mysql':
|
||||
conn.execute(text("SELECT 1"))
|
||||
server_version = conn.execute(text("SELECT VERSION()")).scalar()
|
||||
elif db_type.lower() == 'postgresql':
|
||||
conn.execute(text("SELECT 1"))
|
||||
server_version = conn.execute(text("SELECT version()")).scalar()
|
||||
elif db_type.lower() == 'sqlserver':
|
||||
conn.execute(text("SELECT 1"))
|
||||
server_version = conn.execute(text("SELECT @@VERSION")).scalar()
|
||||
else:
|
||||
raise ValueError(f"不支持的数据库类型: {db_type}")
|
||||
|
||||
# 释放临时引擎
|
||||
engine.dispose()
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"db_type": db_type,
|
||||
"connection_url": str(connection_url),
|
||||
"server_version": server_version
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"连接测试失败: {str(e)}")
|
||||
return {
|
||||
"ok": False,
|
||||
"db_type": db_type,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _build_connection_url(self, db_type: str, host: str, port: int,
|
||||
username: str, password: str, database: str = None, **kwargs) -> str:
|
||||
"""构建数据库连接URL"""
|
||||
# 对用户名和密码进行URL编码,防止特殊字符导致解析错误
|
||||
encoded_username = quote_plus(username)
|
||||
encoded_password = quote_plus(password)
|
||||
|
||||
if db_type.lower() == 'mysql':
|
||||
db_part = f"/{database}" if database else ""
|
||||
return f"mysql+pymysql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}?charset=utf8mb4"
|
||||
elif db_type.lower() == 'oracle':
|
||||
# Oracle连接格式 - 根据oracledb文档优化
|
||||
service_name = database or 'XE'
|
||||
|
||||
# 先测试Oracle连接以确保参数正确
|
||||
try:
|
||||
self._test_oracle_connection(
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
service_name=service_name,
|
||||
**kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Oracle连接预测试失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
# 使用简化的SQLAlchemy URL格式
|
||||
# 根据oracledb文档,SQLAlchemy会自动处理连接参数
|
||||
base_url = f"oracle+oracledb://{encoded_username}:{encoded_password}@{host}:{port}/?service_name={service_name}"
|
||||
|
||||
return base_url
|
||||
elif db_type.lower() == 'sqlserver':
|
||||
db_part = f"/{database}" if database else ""
|
||||
return f"mssql+pymssql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"
|
||||
elif db_type.lower() == 'postgresql':
|
||||
db_part = f"/{database}" if database else "/postgres"
|
||||
return f"postgresql+psycopg2://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"
|
||||
else:
|
||||
raise ValueError(f"不支持的数据库类型: {db_type}")
|
||||
|
||||
def _test_oracle_connection(self, host: str, port: int, username: str, password: str,
|
||||
service_name: str = None, **kwargs):
|
||||
"""测试Oracle直接连接"""
|
||||
service_name = service_name or 'XE'
|
||||
|
||||
# 尝试多种连接方式
|
||||
connection_methods = [
|
||||
# 方式1: Easy Connect字符串
|
||||
{
|
||||
'name': 'Easy Connect',
|
||||
'params': {
|
||||
'user': username,
|
||||
'password': password,
|
||||
'dsn': f"{host}:{port}/{service_name}"
|
||||
}
|
||||
},
|
||||
# 方式2: 分离参数
|
||||
{
|
||||
'name': 'Separate Parameters',
|
||||
'params': {
|
||||
'user': username,
|
||||
'password': password,
|
||||
'host': host,
|
||||
'port': port,
|
||||
'service_name': service_name
|
||||
}
|
||||
},
|
||||
# 方式3: 使用SID
|
||||
{
|
||||
'name': 'SID Connection',
|
||||
'params': {
|
||||
'user': username,
|
||||
'password': password,
|
||||
'dsn': oracledb.makedsn(host, port, sid=service_name)
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
last_error = None
|
||||
|
||||
for method in connection_methods:
|
||||
try:
|
||||
logger.info(f"尝试Oracle连接方式: {method['name']}")
|
||||
|
||||
# 尝试连接
|
||||
connection = oracledb.connect(**method['params'])
|
||||
|
||||
# 测试查询
|
||||
cursor = connection.cursor()
|
||||
cursor.execute("SELECT 1 FROM DUAL")
|
||||
result = cursor.fetchone()
|
||||
|
||||
# 获取数据库版本信息
|
||||
cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1")
|
||||
version = cursor.fetchone()
|
||||
|
||||
cursor.close()
|
||||
connection.close()
|
||||
|
||||
logger.info(f"Oracle连接成功 ({method['name']}): 查询结果={result}, 版本={version[0] if version else 'Unknown'}")
|
||||
return # 连接成功,退出函数
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(f"Oracle连接方式 '{method['name']}' 失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 所有连接方式都失败
|
||||
error_msg = f"所有Oracle连接方式都失败。最后一个错误: {str(last_error)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def get_engine(self, connection_id: str):
|
||||
"""获取数据库引擎"""
|
||||
if connection_id not in self.engines:
|
||||
raise ValueError(f"连接ID不存在: {connection_id}")
|
||||
return self.engines[connection_id]
|
||||
|
||||
def get_session(self, connection_id: str):
|
||||
"""获取数据库会话"""
|
||||
if connection_id not in self.sessions:
|
||||
raise ValueError(f"连接ID不存在: {connection_id}")
|
||||
return self.sessions[connection_id]()
|
||||
|
||||
def execute_query(self, connection_id: str, sql: str, params: Dict = None) -> List[Dict]:
|
||||
"""执行查询SQL"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(sql), params or {})
|
||||
columns = result.keys()
|
||||
rows = result.fetchall()
|
||||
return [dict(zip(columns, row)) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"执行查询失败: {str(e)}")
|
||||
raise Exception(f"查询执行失败: {str(e)}")
|
||||
|
||||
def execute_non_query(self, connection_id: str, sql: str, params: Dict = None) -> int:
|
||||
"""执行非查询SQL(INSERT, UPDATE, DELETE)"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
with engine.connect() as conn:
|
||||
with conn.begin():
|
||||
result = conn.execute(text(sql), params or {})
|
||||
return result.rowcount
|
||||
except Exception as e:
|
||||
logger.error(f"执行非查询失败: {str(e)}")
|
||||
raise Exception(f"非查询执行失败: {str(e)}")
|
||||
|
||||
def get_database_info(self, connection_id: str) -> Dict:
|
||||
"""获取数据库信息"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
inspector = inspect(engine)
|
||||
|
||||
# 获取数据库名称
|
||||
with engine.connect() as conn:
|
||||
db_name_result = conn.execute(text("SELECT DATABASE()" if 'mysql' in str(engine.url)
|
||||
else "SELECT CURRENT_DATABASE()" if 'postgresql' in str(engine.url)
|
||||
else "SELECT DB_NAME()" if 'mssql' in str(engine.url)
|
||||
else "SELECT SYS_CONTEXT('USERENV', 'DB_NAME') FROM DUAL"))
|
||||
db_name = db_name_result.scalar()
|
||||
|
||||
# 获取表列表
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
return {
|
||||
"database_name": db_name,
|
||||
"tables": tables,
|
||||
"table_count": len(tables)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库信息失败: {str(e)}")
|
||||
raise Exception(f"获取数据库信息失败: {str(e)}")
|
||||
|
||||
def get_tables_with_comments(self, connection_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取数据库中所有表及其备注信息"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
url_str = str(engine.url)
|
||||
with engine.connect() as conn:
|
||||
if 'mysql' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
TABLE_NAME AS table_name,
|
||||
TABLE_COMMENT AS table_comment
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
ORDER BY TABLE_NAME
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
|
||||
elif 'postgresql' in url_str:
|
||||
# 获取当前schema
|
||||
schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public'
|
||||
sql = """
|
||||
SELECT
|
||||
c.relname AS table_name,
|
||||
obj_description(c.oid) AS table_comment
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE n.nspname = :schema
|
||||
AND c.relkind IN ('r','p')
|
||||
ORDER BY c.relname
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql), {"schema": schema}).mappings().all()]
|
||||
elif 'mssql' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
t.name AS table_name,
|
||||
CAST(ep.value AS NVARCHAR(MAX)) AS table_comment
|
||||
FROM sys.tables t
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = t.object_id
|
||||
AND ep.minor_id = 0
|
||||
AND ep.name = 'MS_Description'
|
||||
ORDER BY t.name
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
|
||||
elif 'oracle' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
table_name AS table_name,
|
||||
comments AS table_comment
|
||||
FROM user_tab_comments
|
||||
ORDER BY table_name
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
|
||||
else:
|
||||
raise ValueError("不支持的数据库类型")
|
||||
except Exception as e:
|
||||
logger.error(f"获取表备注信息失败: {str(e)}")
|
||||
raise Exception(f"获取表备注信息失败: {str(e)}")
|
||||
|
||||
def get_table_info(self, connection_id: str, table_name: str) -> Dict:
|
||||
"""获取表信息
|
||||
|
||||
序列化说明:
|
||||
- SQLAlchemy 的列类型对象(如 `VARCHAR`、`NUMBER` 等)不可直接JSON序列化。
|
||||
- 本方法将列的 `type` 字段统一转换为字符串表示(例如 `VARCHAR(255)`)。
|
||||
- 其他由 Inspector 返回的结构(主键、外键、索引)保持原始可序列化格式。
|
||||
"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
inspector = inspect(engine)
|
||||
|
||||
# 获取列信息
|
||||
raw_columns = inspector.get_columns(table_name)
|
||||
columns = []
|
||||
for col in raw_columns:
|
||||
# 创建可序列化的列字典
|
||||
serializable = dict(col)
|
||||
col_type = serializable.get('type')
|
||||
# 将SQLAlchemy类型对象转换为字符串
|
||||
if col_type is not None:
|
||||
try:
|
||||
serializable['type'] = str(col_type)
|
||||
except Exception:
|
||||
# 兜底:使用类型名
|
||||
serializable['type'] = type(col_type).__name__
|
||||
columns.append(serializable)
|
||||
|
||||
# 获取主键
|
||||
primary_keys = inspector.get_pk_constraint(table_name)
|
||||
|
||||
# 获取外键
|
||||
foreign_keys = inspector.get_foreign_keys(table_name)
|
||||
|
||||
# 获取索引
|
||||
indexes = inspector.get_indexes(table_name)
|
||||
|
||||
return {
|
||||
"table_name": table_name,
|
||||
"columns": columns,
|
||||
"primary_keys": primary_keys,
|
||||
"foreign_keys": foreign_keys,
|
||||
"indexes": indexes
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取表信息失败: {str(e)}")
|
||||
raise Exception(f"获取表信息失败: {str(e)}")
|
||||
|
||||
def get_table_columns(self, connection_id: str, table_name: str) -> List[Dict[str, Any]]:
|
||||
"""获取指定表的字段名、类型、备注等信息(兼容多数据库)"""
|
||||
try:
|
||||
engine = self.get_engine(connection_id)
|
||||
url_str = str(engine.url)
|
||||
with engine.connect() as conn:
|
||||
if 'mysql' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
COLUMN_NAME AS column_name,
|
||||
DATA_TYPE AS data_type,
|
||||
IS_NULLABLE AS is_nullable,
|
||||
COLUMN_DEFAULT AS column_default,
|
||||
COLUMN_COMMENT AS column_comment,
|
||||
CHARACTER_MAXIMUM_LENGTH AS max_length,
|
||||
NUMERIC_PRECISION AS numeric_precision,
|
||||
NUMERIC_SCALE AS numeric_scale
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = :table_name
|
||||
ORDER BY ORDINAL_POSITION
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
|
||||
elif 'postgresql' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
a.attname AS column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
||||
NOT a.attnotnull AS is_nullable,
|
||||
pg_get_expr(d.adbin, d.adrelid) AS column_default,
|
||||
col_description(a.attrelid, a.attnum) AS column_comment
|
||||
FROM pg_attribute a
|
||||
JOIN pg_class c ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
|
||||
WHERE c.relname = :table_name
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
ORDER BY a.attnum
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
|
||||
elif 'mssql' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
c.name AS column_name,
|
||||
t.name AS data_type,
|
||||
CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable,
|
||||
OBJECT_DEFINITION(c.default_object_id) AS column_default,
|
||||
CAST(ep.value AS NVARCHAR(MAX)) AS column_comment,
|
||||
COLUMNPROPERTY(c.object_id, c.name, 'charmaxlen') AS max_length
|
||||
FROM sys.columns c
|
||||
LEFT JOIN sys.types t ON c.user_type_id = t.user_type_id
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = c.object_id
|
||||
AND ep.minor_id = c.column_id
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE c.object_id = OBJECT_ID(:table_name)
|
||||
ORDER BY c.column_id
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
|
||||
elif 'oracle' in url_str:
|
||||
sql = """
|
||||
SELECT
|
||||
cols.column_name AS column_name,
|
||||
cols.data_type AS data_type,
|
||||
cols.nullable AS is_nullable,
|
||||
cols.data_default AS column_default,
|
||||
comm.comments AS column_comment,
|
||||
cols.data_length AS max_length
|
||||
FROM user_tab_columns cols
|
||||
LEFT JOIN user_col_comments comm
|
||||
ON comm.table_name = cols.table_name
|
||||
AND comm.column_name = cols.column_name
|
||||
WHERE cols.table_name = UPPER(:table_name)
|
||||
ORDER BY cols.column_id
|
||||
"""
|
||||
return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
|
||||
else:
|
||||
raise ValueError("不支持的数据库类型")
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段信息失败: {str(e)}")
|
||||
raise Exception(f"获取字段信息失败: {str(e)}")
|
||||
|
||||
def close_connection(self, connection_id: str):
|
||||
"""关闭数据库连接"""
|
||||
try:
|
||||
if connection_id in self.engines:
|
||||
self.engines[connection_id].dispose()
|
||||
del self.engines[connection_id]
|
||||
del self.sessions[connection_id]
|
||||
logger.info(f"已关闭数据库连接: {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭连接失败: {str(e)}")
|
||||
|
||||
def list_connections(self) -> List[str]:
|
||||
"""列出所有活动连接"""
|
||||
return list(self.engines.keys())
|
||||
|
||||
# 全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
Reference in New Issue
Block a user