Files
dababase-etl-python/database_manager.py
2026-03-04 12:17:52 +08:00

509 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy import create_engine, text, MetaData, Table, Column, inspect
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError
from typing import Dict, List, Any, Optional
import logging
import oracledb
from urllib.parse import quote_plus
# Configure logging: basicConfig is called at INFO level when this module is
# imported, and `logger` is the module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Manage SQLAlchemy engines and session factories for several databases.

    Supported ``db_type`` values are ``mysql``, ``oracle``, ``sqlserver`` and
    ``postgresql`` (compared case-insensitively).  Every successfully created
    connection is stored under a connection ID string that the other methods
    take as their first argument.
    """

    def __init__(self):
        # connection_id -> SQLAlchemy Engine
        self.engines = {}
        # connection_id -> sessionmaker bound to that engine
        self.sessions = {}

    def create_connection(self, db_type: str, host: str, port: int,
                          username: str, password: str,
                          database: Optional[str] = None, **kwargs) -> str:
        """Create, verify and register a database connection.

        Args:
            db_type: Database type (mysql, oracle, sqlserver, postgresql).
            host: Database host address.
            port: Database port.
            username: User name.
            password: Password.
            database: Database name (service name for Oracle).
            **kwargs: Forwarded to ``_build_connection_url``.

        Returns:
            The connection ID, ``"{db_type}_{host}_{port}_{database or 'default'}"``.

        Raises:
            Exception: If the URL cannot be built or the probe query fails.
                The original error is attached as ``__cause__``.
        """
        try:
            connection_url = self._build_connection_url(
                db_type, host, port, username, password, database, **kwargs)
            connection_id = f"{db_type}_{host}_{port}_{database or 'default'}"
            is_oracle = db_type.lower() == 'oracle'

            # Engine options
            engine_kwargs = {"echo": False, "pool_pre_ping": True}
            if is_oracle:
                # Oracle connection-pool settings
                engine_kwargs.update({
                    "pool_size": 5,
                    "max_overflow": 10,
                    "pool_timeout": 30,
                    "pool_recycle": 3600,
                    "pool_reset_on_return": "commit",
                })
            engine = create_engine(connection_url, **engine_kwargs)

            # Probe the connection; release the pool if the probe fails so a
            # bad configuration does not leave an orphaned engine behind.
            try:
                with engine.connect() as conn:
                    conn.execute(text("SELECT 1 FROM DUAL" if is_oracle else "SELECT 1"))
            except Exception:
                engine.dispose()
                raise

            # The ID does not include the user name, so a second call for the
            # same host/port/database replaces the earlier entry.  Dispose the
            # replaced engine instead of silently dropping its pool.
            stale_engine = self.engines.get(connection_id)
            self.engines[connection_id] = engine
            self.sessions[connection_id] = sessionmaker(
                autocommit=False, autoflush=False, bind=engine)
            if stale_engine is not None and stale_engine is not engine:
                try:
                    stale_engine.dispose()
                except Exception as dispose_error:
                    logger.warning(f"释放旧连接池失败: {connection_id}: {str(dispose_error)}")

            logger.info(f"成功创建数据库连接: {connection_id}")
            return connection_id
        except Exception as e:
            logger.error(f"创建数据库连接失败: {str(e)}")
            raise Exception(f"数据库连接失败: {str(e)}") from e

    def test_connection(self, db_type: str, host: str, port: int,
                        username: str, password: str,
                        database: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """Check whether a database is reachable without registering it.

        Args:
            db_type: Database type (mysql, oracle, sqlserver, postgresql).
            host: Database host address.
            port: Database port.
            username: User name.
            password: Password.
            database: Database name (service name for Oracle).
            **kwargs: Forwarded to ``_build_connection_url``.

        Returns:
            On success ``{"ok": True, "db_type", "connection_url",
            "server_version"}``; on failure ``{"ok": False, "db_type",
            "error"}``.  ``connection_url`` is rendered with the password
            masked so the result can be logged or returned to a client.
            This method does not raise; failures are reported in the dict.
        """
        # (ping statement, version statement) per supported database type
        probes = {
            'oracle': ("SELECT 1 FROM DUAL", "SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1"),
            'mysql': ("SELECT 1", "SELECT VERSION()"),
            'postgresql': ("SELECT 1", "SELECT version()"),
            'sqlserver': ("SELECT 1", "SELECT @@VERSION"),
        }
        engine = None
        try:
            # Building the URL also runs the direct Oracle pre-test.
            connection_url = self._build_connection_url(
                db_type, host, port, username, password, database, **kwargs)
            db_kind = db_type.lower()
            if db_kind not in probes:
                raise ValueError(f"不支持的数据库类型: {db_type}")
            ping_sql, version_sql = probes[db_kind]

            engine = create_engine(connection_url, echo=False, pool_pre_ping=True)
            with engine.connect() as conn:
                # Basic connectivity check, then the server version
                conn.execute(text(ping_sql))
                server_version = conn.execute(text(version_sql)).scalar()

            return {
                "ok": True,
                "db_type": db_type,
                "connection_url": engine.url.render_as_string(hide_password=True),
                "server_version": server_version,
            }
        except Exception as e:
            logger.error(f"连接测试失败: {str(e)}")
            return {
                "ok": False,
                "db_type": db_type,
                "error": str(e),
            }
        finally:
            # Release the temporary engine on success and on failure alike.
            if engine is not None:
                engine.dispose()

    def _build_connection_url(self, db_type: str, host: str, port: int,
                              username: str, password: str,
                              database: Optional[str] = None, **kwargs) -> str:
        """Build the SQLAlchemy connection URL for ``db_type``.

        For Oracle this first opens a direct ``oracledb`` connection through
        ``_test_oracle_connection`` and re-raises its error if that fails;
        ``database`` is used as the service name and defaults to ``XE``.

        Raises:
            ValueError: If ``db_type`` is not one of the supported types.
        """
        # URL-encode the credentials so special characters cannot break parsing.
        encoded_username = quote_plus(username)
        encoded_password = quote_plus(password)
        db_kind = db_type.lower()

        if db_kind == 'mysql':
            db_part = f"/{database}" if database else ""
            return f"mysql+pymysql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}?charset=utf8mb4"

        if db_kind == 'oracle':
            service_name = database or 'XE'
            # Pre-test with the raw driver to make sure the parameters work.
            try:
                self._test_oracle_connection(
                    host=host,
                    port=port,
                    username=username,
                    password=password,
                    service_name=service_name,
                    **kwargs
                )
            except Exception as e:
                logger.error(f"Oracle连接预测试失败: {str(e)}")
                raise
            # Simplified SQLAlchemy URL form; the service name travels as a
            # query parameter.
            return f"oracle+oracledb://{encoded_username}:{encoded_password}@{host}:{port}/?service_name={service_name}"

        if db_kind == 'sqlserver':
            db_part = f"/{database}" if database else ""
            return f"mssql+pymssql://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"

        if db_kind == 'postgresql':
            db_part = f"/{database}" if database else "/postgres"
            return f"postgresql+psycopg2://{encoded_username}:{encoded_password}@{host}:{port}{db_part}"

        raise ValueError(f"不支持的数据库类型: {db_type}")

    def _test_oracle_connection(self, host: str, port: int, username: str, password: str,
                                service_name: Optional[str] = None, **kwargs):
        """Try a direct ``oracledb`` connection using several DSN styles.

        The styles are attempted in order (Easy Connect string, separate
        parameters, SID-based DSN); the method returns as soon as one of them
        connects and answers the probe queries.  ``**kwargs`` is accepted for
        call compatibility and is not used.

        Raises:
            Exception: If every connection style fails; the last driver error
                is attached as ``__cause__``.
        """
        service_name = service_name or 'XE'
        connection_methods = [
            # Style 1: Easy Connect string
            {
                'name': 'Easy Connect',
                'params': {
                    'user': username,
                    'password': password,
                    'dsn': f"{host}:{port}/{service_name}",
                },
            },
            # Style 2: separate parameters
            {
                'name': 'Separate Parameters',
                'params': {
                    'user': username,
                    'password': password,
                    'host': host,
                    'port': port,
                    'service_name': service_name,
                },
            },
            # Style 3: treat the name as a SID
            {
                'name': 'SID Connection',
                'params': {
                    'user': username,
                    'password': password,
                    'dsn': oracledb.makedsn(host, port, sid=service_name),
                },
            },
        ]

        last_error = None
        for method in connection_methods:
            try:
                logger.info(f"尝试Oracle连接方式: {method['name']}")
                # Context managers close the cursor and the connection even
                # when one of the probe queries raises.
                with oracledb.connect(**method['params']) as connection:
                    with connection.cursor() as cursor:
                        cursor.execute("SELECT 1 FROM DUAL")
                        result = cursor.fetchone()
                        # Database version banner
                        cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1")
                        version = cursor.fetchone()
                logger.info(f"Oracle连接成功 ({method['name']}): 查询结果={result}, 版本={version[0] if version else 'Unknown'}")
                return  # connected; nothing more to try
            except Exception as e:
                last_error = e
                logger.warning(f"Oracle连接方式 '{method['name']}' 失败: {str(e)}")

        # Every style failed
        error_msg = f"所有Oracle连接方式都失败。最后一个错误: {str(last_error)}"
        logger.error(error_msg)
        raise Exception(error_msg) from last_error

    def get_engine(self, connection_id: str):
        """Return the engine registered under ``connection_id``.

        Raises:
            ValueError: If the connection ID is unknown.
        """
        if connection_id not in self.engines:
            raise ValueError(f"连接ID不存在: {connection_id}")
        return self.engines[connection_id]

    def get_session(self, connection_id: str):
        """Return a new session from the factory registered under ``connection_id``.

        Raises:
            ValueError: If the connection ID is unknown.
        """
        if connection_id not in self.sessions:
            raise ValueError(f"连接ID不存在: {connection_id}")
        return self.sessions[connection_id]()

    def execute_query(self, connection_id: str, sql: str,
                      params: Optional[Dict] = None) -> List[Dict]:
        """Run a row-returning SQL statement and return the rows as dicts.

        ``params`` are bound to the statement's named parameters.

        Raises:
            Exception: On any failure; the original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            with engine.connect() as conn:
                result = conn.execute(text(sql), params or {})
                columns = list(result.keys())
                return [dict(zip(columns, row)) for row in result.fetchall()]
        except Exception as e:
            logger.error(f"执行查询失败: {str(e)}")
            raise Exception(f"查询执行失败: {str(e)}") from e

    def execute_non_query(self, connection_id: str, sql: str,
                          params: Optional[Dict] = None) -> int:
        """Run an INSERT/UPDATE/DELETE style statement inside a transaction.

        Returns:
            The ``rowcount`` reported by the result.

        Raises:
            Exception: On any failure; the original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            # engine.begin() commits on success and rolls back on error.
            with engine.begin() as conn:
                result = conn.execute(text(sql), params or {})
                return result.rowcount
        except Exception as e:
            logger.error(f"执行非查询失败: {str(e)}")
            raise Exception(f"非查询执行失败: {str(e)}") from e

    def get_database_info(self, connection_id: str) -> Dict:
        """Return the current database name and its table list.

        Returns:
            ``{"database_name", "tables", "table_count"}``.

        Raises:
            Exception: On any failure; the original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            inspector = inspect(engine)

            # Decide by dialect name, not by searching the URL text: a host,
            # user or database name containing e.g. "mysql" must not change
            # which statement is sent.
            dialect = engine.dialect.name
            if dialect == 'mysql':
                name_sql = "SELECT DATABASE()"
            elif dialect == 'postgresql':
                name_sql = "SELECT CURRENT_DATABASE()"
            elif dialect == 'mssql':
                name_sql = "SELECT DB_NAME()"
            else:
                name_sql = "SELECT SYS_CONTEXT('USERENV', 'DB_NAME') FROM DUAL"

            with engine.connect() as conn:
                db_name = conn.execute(text(name_sql)).scalar()

            tables = inspector.get_table_names()
            return {
                "database_name": db_name,
                "tables": tables,
                "table_count": len(tables),
            }
        except Exception as e:
            logger.error(f"获取数据库信息失败: {str(e)}")
            raise Exception(f"获取数据库信息失败: {str(e)}") from e

    def get_tables_with_comments(self, connection_id: str) -> List[Dict[str, Any]]:
        """Return every table of the current database/schema with its comment.

        Returns:
            A list of ``{"table_name", "table_comment"}`` dicts ordered by
            table name.

        Raises:
            Exception: On any failure, including an unsupported dialect; the
                original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            dialect = engine.dialect.name
            bind_params: Dict[str, Any] = {}

            with engine.connect() as conn:
                if dialect == 'mysql':
                    sql = """
                        SELECT
                            TABLE_NAME AS table_name,
                            TABLE_COMMENT AS table_comment
                        FROM INFORMATION_SCHEMA.TABLES
                        WHERE TABLE_SCHEMA = DATABASE()
                        ORDER BY TABLE_NAME
                    """
                elif dialect == 'postgresql':
                    # Restrict to the current schema
                    schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public'
                    bind_params = {"schema": schema}
                    sql = """
                        SELECT
                            c.relname AS table_name,
                            obj_description(c.oid) AS table_comment
                        FROM pg_class c
                        JOIN pg_namespace n ON n.oid = c.relnamespace
                        WHERE n.nspname = :schema
                          AND c.relkind IN ('r','p')
                        ORDER BY c.relname
                    """
                elif dialect == 'mssql':
                    # ep.class = 1 limits the join to object/column properties;
                    # other property classes reuse major_id/minor_id for
                    # unrelated IDs and could duplicate rows.
                    sql = """
                        SELECT
                            t.name AS table_name,
                            CAST(ep.value AS NVARCHAR(MAX)) AS table_comment
                        FROM sys.tables t
                        LEFT JOIN sys.extended_properties ep
                            ON ep.class = 1
                           AND ep.major_id = t.object_id
                           AND ep.minor_id = 0
                           AND ep.name = 'MS_Description'
                        ORDER BY t.name
                    """
                elif dialect == 'oracle':
                    sql = """
                        SELECT
                            table_name AS table_name,
                            comments AS table_comment
                        FROM user_tab_comments
                        ORDER BY table_name
                    """
                else:
                    raise ValueError("不支持的数据库类型")

                rows = conn.execute(text(sql), bind_params).mappings().all()
                return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"获取表备注信息失败: {str(e)}")
            raise Exception(f"获取表备注信息失败: {str(e)}") from e

    def get_table_info(self, connection_id: str, table_name: str) -> Dict:
        """Return columns, primary key, foreign keys and indexes of a table.

        Serialization note: SQLAlchemy column type objects (``VARCHAR``,
        ``NUMBER`` ...) are not JSON-serializable, so each column's ``type``
        entry is replaced by its string form (for example ``VARCHAR(255)``).
        The primary-key, foreign-key and index structures are returned as the
        inspector produced them.

        Raises:
            Exception: On any failure; the original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            inspector = inspect(engine)

            columns = []
            for col in inspector.get_columns(table_name):
                serializable = dict(col)
                col_type = serializable.get('type')
                if col_type is not None:
                    try:
                        serializable['type'] = str(col_type)
                    except Exception:
                        # Fallback: use the type's class name
                        serializable['type'] = type(col_type).__name__
                columns.append(serializable)

            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": inspector.get_pk_constraint(table_name),
                "foreign_keys": inspector.get_foreign_keys(table_name),
                "indexes": inspector.get_indexes(table_name),
            }
        except Exception as e:
            logger.error(f"获取表信息失败: {str(e)}")
            raise Exception(f"获取表信息失败: {str(e)}") from e

    def get_table_columns(self, connection_id: str, table_name: str) -> List[Dict[str, Any]]:
        """Return name, type, nullability, default and comment for each column.

        The set of keys depends on the dialect: every dialect returns
        ``column_name``, ``data_type``, ``is_nullable``, ``column_default``
        and ``column_comment``; MySQL, SQL Server and Oracle add
        ``max_length`` and MySQL also adds ``numeric_precision`` and
        ``numeric_scale``.  ``is_nullable`` is a boolean expression on
        PostgreSQL and a text flag on the other dialects.

        Raises:
            Exception: On any failure, including an unsupported dialect; the
                original error is the ``__cause__``.
        """
        try:
            engine = self.get_engine(connection_id)
            dialect = engine.dialect.name

            if dialect == 'mysql':
                sql = """
                    SELECT
                        COLUMN_NAME AS column_name,
                        DATA_TYPE AS data_type,
                        IS_NULLABLE AS is_nullable,
                        COLUMN_DEFAULT AS column_default,
                        COLUMN_COMMENT AS column_comment,
                        CHARACTER_MAXIMUM_LENGTH AS max_length,
                        NUMERIC_PRECISION AS numeric_precision,
                        NUMERIC_SCALE AS numeric_scale
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = DATABASE()
                      AND TABLE_NAME = :table_name
                    ORDER BY ORDINAL_POSITION
                """
            elif dialect == 'postgresql':
                # pg_table_is_visible keeps only the relation an unqualified
                # name resolves to, so same-named tables in other schemas do
                # not contribute columns.
                sql = """
                    SELECT
                        a.attname AS column_name,
                        pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
                        NOT a.attnotnull AS is_nullable,
                        pg_get_expr(d.adbin, d.adrelid) AS column_default,
                        col_description(a.attrelid, a.attnum) AS column_comment
                    FROM pg_attribute a
                    JOIN pg_class c ON a.attrelid = c.oid
                    LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
                    WHERE c.relname = :table_name
                      AND pg_catalog.pg_table_is_visible(c.oid)
                      AND a.attnum > 0
                      AND NOT a.attisdropped
                    ORDER BY a.attnum
                """
            elif dialect == 'mssql':
                # ep.class = 1 limits the join to object/column properties.
                sql = """
                    SELECT
                        c.name AS column_name,
                        t.name AS data_type,
                        CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable,
                        OBJECT_DEFINITION(c.default_object_id) AS column_default,
                        CAST(ep.value AS NVARCHAR(MAX)) AS column_comment,
                        COLUMNPROPERTY(c.object_id, c.name, 'charmaxlen') AS max_length
                    FROM sys.columns c
                    LEFT JOIN sys.types t ON c.user_type_id = t.user_type_id
                    LEFT JOIN sys.extended_properties ep
                        ON ep.class = 1
                       AND ep.major_id = c.object_id
                       AND ep.minor_id = c.column_id
                       AND ep.name = 'MS_Description'
                    WHERE c.object_id = OBJECT_ID(:table_name)
                    ORDER BY c.column_id
                """
            elif dialect == 'oracle':
                sql = """
                    SELECT
                        cols.column_name AS column_name,
                        cols.data_type AS data_type,
                        cols.nullable AS is_nullable,
                        cols.data_default AS column_default,
                        comm.comments AS column_comment,
                        cols.data_length AS max_length
                    FROM user_tab_columns cols
                    LEFT JOIN user_col_comments comm
                        ON comm.table_name = cols.table_name
                       AND comm.column_name = cols.column_name
                    WHERE cols.table_name = UPPER(:table_name)
                    ORDER BY cols.column_id
                """
            else:
                raise ValueError("不支持的数据库类型")

            with engine.connect() as conn:
                rows = conn.execute(text(sql), {"table_name": table_name}).mappings().all()
                return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"获取字段信息失败: {str(e)}")
            raise Exception(f"获取字段信息失败: {str(e)}") from e

    def close_connection(self, connection_id: str):
        """Dispose and unregister the connection; unknown IDs are ignored.

        Errors raised while disposing are logged, not propagated.
        """
        try:
            # Remove both registry entries first so the two dictionaries stay
            # in sync even if dispose() fails.
            engine = self.engines.pop(connection_id, None)
            self.sessions.pop(connection_id, None)
            if engine is not None:
                engine.dispose()
                logger.info(f"已关闭数据库连接: {connection_id}")
        except Exception as e:
            logger.error(f"关闭连接失败: {str(e)}")

    def list_connections(self) -> List[str]:
        """Return the IDs of all registered connections."""
        return list(self.engines.keys())
# Global database manager instance, created when this module is imported.
db_manager = DatabaseManager()