from sqlalchemy import create_engine, text, MetaData, Table, Column, inspect
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError
from typing import Dict, List, Any, Optional
import logging
import oracledb
from urllib.parse import quote, quote_plus

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DatabaseManager:
    """Manage connections to several database types and run SQL against them.

    Supported ``db_type`` values: ``mysql``, ``oracle``, ``sqlserver``, ``postgresql``.
    Each successful :meth:`create_connection` registers one SQLAlchemy engine in
    ``self.engines`` and one session factory in ``self.sessions`` under the same
    connection id.
    """

    # Query returning a human-readable server version string, keyed by db_type.
    _VERSION_SQL = {
        'oracle': "SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1",
        'mysql': "SELECT VERSION()",
        'postgresql': "SELECT version()",
        'sqlserver': "SELECT @@VERSION",
    }

    # Query returning the current database name, keyed by SQLAlchemy dialect name
    # (engine.dialect.name). Dialects not listed fall back to the Oracle query.
    _CURRENT_DB_SQL = {
        'mysql': "SELECT DATABASE()",
        'postgresql': "SELECT CURRENT_DATABASE()",
        'mssql': "SELECT DB_NAME()",
    }
    _ORACLE_CURRENT_DB_SQL = "SELECT SYS_CONTEXT('USERENV', 'DB_NAME') FROM DUAL"

    # Name of the Oracle pre-test method that connects by SID rather than service name.
    _ORACLE_SID_METHOD = 'SID Connection'

    def __init__(self):
        # connection_id -> SQLAlchemy Engine
        self.engines = {}
        # connection_id -> sessionmaker bound to the engine with the same id
        self.sessions = {}

    def create_connection(self, db_type: str, host: str, port: int, username: str, password: str,
                          database: Optional[str] = None, **kwargs) -> str:
        """Create a pooled connection, verify it, and register it under a connection id.

        Args:
            db_type: Database type (mysql, oracle, sqlserver, postgresql).
            host: Database host address.
            port: Database port.
            username: User name.
            password: Password.
            database: Database name (for Oracle: the service name / SID, default 'XE').
            **kwargs: Forwarded to the URL builder.

        Returns:
            The connection id ``f"{db_type}_{host}_{port}_{database or 'default'}"``.
            Calling this again with the same parameters replaces (and disposes)
            the previously registered engine.

        Raises:
            Exception: If the URL cannot be built or the test query fails; the
                underlying error is chained as ``__cause__``.
        """
        engine = None
        try:
            connection_url = self._build_connection_url(db_type, host, port, username, password,
                                                        database, **kwargs)
            connection_id = f"{db_type}_{host}_{port}_{database or 'default'}"

            # Engine options shared by all database types
            engine_kwargs = {"echo": False, "pool_pre_ping": True}
            if db_type.lower() == 'oracle':
                # Oracle connection-pool configuration
                engine_kwargs.update({
                    "pool_size": 5,
                    "max_overflow": 10,
                    "pool_timeout": 30,
                    "pool_recycle": 3600,
                    "pool_reset_on_return": "commit",
                })

            engine = create_engine(connection_url, **engine_kwargs)

            # Make sure the engine can actually reach the server before registering it.
            ping_sql = "SELECT 1 FROM DUAL" if db_type.lower() == 'oracle' else "SELECT 1"
            with engine.connect() as conn:
                conn.execute(text(ping_sql))
        except Exception as e:
            # Do not leave a pool behind for an engine that will never be registered.
            if engine is not None:
                engine.dispose()
            logger.error(f"创建数据库连接失败: {str(e)}")
            raise Exception(f"数据库连接失败: {str(e)}") from e

        # Re-creating an existing id replaces its engine; release the old pool so it
        # does not leak (connections still checked out stay usable until returned).
        previous_engine = self.engines.get(connection_id)
        self.engines[connection_id] = engine
        self.sessions[connection_id] = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        if previous_engine is not None:
            previous_engine.dispose()

        logger.info(f"成功创建数据库连接: {connection_id}")
        return connection_id

    def test_connection(self, db_type: str, host: str, port: int, username: str, password: str,
                        database: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """Check whether a database is reachable, without registering a connection.

        Args:
            db_type: Database type (mysql, oracle, sqlserver, postgresql).
            host: Database host address.
            port: Database port.
            username: User name.
            password: Password.
            database: Database name.

        Returns:
            On success: ``{"ok": True, "db_type", "connection_url", "server_version"}``,
            where ``connection_url`` has the password masked and ``server_version``
            is ``None`` if the version query could not be run.
            On failure: ``{"ok": False, "db_type", "error"}``. Never raises.
        """
        engine = None
        try:
            # Building the URL also pre-tests Oracle connectivity with the raw driver.
            connection_url = self._build_connection_url(db_type, host, port, username, password,
                                                        database, **kwargs)
            kind = db_type.lower()
            if kind not in self._VERSION_SQL:
                raise ValueError(f"不支持的数据库类型: {db_type}")

            engine = create_engine(connection_url, echo=False, pool_pre_ping=True)
            with engine.connect() as conn:
                # Basic connectivity check
                conn.execute(text("SELECT 1 FROM DUAL" if kind == 'oracle' else "SELECT 1"))
                server_version = self._fetch_server_version(conn, kind)

            return {
                "ok": True,
                "db_type": db_type,
                # Never echo the password back to the caller.
                "connection_url": engine.url.render_as_string(hide_password=True),
                "server_version": server_version,
            }
        except Exception as e:
            logger.error(f"连接测试失败: {str(e)}")
            return {
                "ok": False,
                "db_type": db_type,
                "error": str(e),
            }
        finally:
            # The engine is temporary: release its pool on success and on failure.
            if engine is not None:
                engine.dispose()

    def _fetch_server_version(self, conn, kind: str) -> Optional[str]:
        """Return the server version string for ``kind``, or None if it cannot be read.

        The version query may require extra privileges (e.g. Oracle's V$VERSION), so a
        failure is logged and ignored instead of failing an otherwise successful check.
        """
        try:
            return conn.execute(text(self._VERSION_SQL[kind])).scalar()
        except SQLAlchemyError as e:
            logger.warning(f"获取数据库版本信息失败: {str(e)}")
            return None

    def _build_connection_url(self, db_type: str, host: str, port: int, username: str, password: str,
                              database: Optional[str] = None, **kwargs) -> str:
        """Build the SQLAlchemy connection URL for ``db_type``.

        For Oracle, a direct driver connection is attempted first; the URL then uses
        SID form if only the SID method succeeded, otherwise the service-name form.

        Raises:
            ValueError: For an unsupported ``db_type``.
            Exception: If every Oracle pre-test connection method fails.
        """
        # Percent-encode credentials so special characters cannot break URL parsing.
        # quote(safe="") is used rather than quote_plus: SQLAlchemy decodes the
        # credentials with urllib.parse.unquote, which does not turn '+' back into a
        # space, so quote_plus would corrupt passwords containing spaces.
        encoded_username = quote(username, safe="")
        encoded_password = quote(password, safe="")
        credentials = f"{encoded_username}:{encoded_password}"
        kind = db_type.lower()

        if kind == 'mysql':
            db_part = f"/{database}" if database else ""
            return f"mysql+pymysql://{credentials}@{host}:{port}{db_part}?charset=utf8mb4"

        if kind == 'oracle':
            service_name = database or 'XE'
            # Pre-test with the raw driver so parameter problems surface early.
            try:
                method_name = self._test_oracle_connection(
                    host=host,
                    port=port,
                    username=username,
                    password=password,
                    service_name=service_name,
                    **kwargs
                )
            except Exception as e:
                logger.error(f"Oracle连接预测试失败: {str(e)}")
                raise
            if method_name == self._ORACLE_SID_METHOD:
                # The SQLAlchemy Oracle dialect interprets a URL path as a SID.
                return f"oracle+oracledb://{credentials}@{host}:{port}/{service_name}"
            return f"oracle+oracledb://{credentials}@{host}:{port}/?service_name={service_name}"

        if kind == 'sqlserver':
            db_part = f"/{database}" if database else ""
            return f"mssql+pymssql://{credentials}@{host}:{port}{db_part}"

        if kind == 'postgresql':
            db_part = f"/{database}" if database else "/postgres"
            return f"postgresql+psycopg2://{credentials}@{host}:{port}{db_part}"

        raise ValueError(f"不支持的数据库类型: {db_type}")

    def _test_oracle_connection(self, host: str, port: int, username: str, password: str,
                                service_name: Optional[str] = None, **kwargs) -> str:
        """Try several python-oracledb connection styles until one works.

        Returns:
            The ``name`` of the first method that connected and ran ``SELECT 1 FROM DUAL``
            ('Easy Connect', 'Separate Parameters' or 'SID Connection').

        Raises:
            Exception: If every method fails; the message includes the last error.
        """
        service_name = service_name or 'XE'

        connection_methods = [
            # Method 1: Easy Connect string
            {
                'name': 'Easy Connect',
                'params': {
                    'user': username,
                    'password': password,
                    'dsn': f"{host}:{port}/{service_name}"
                }
            },
            # Method 2: separate host/port/service_name parameters
            {
                'name': 'Separate Parameters',
                'params': {
                    'user': username,
                    'password': password,
                    'host': host,
                    'port': port,
                    'service_name': service_name
                }
            },
            # Method 3: treat the name as a SID
            {
                'name': self._ORACLE_SID_METHOD,
                'params': {
                    'user': username,
                    'password': password,
                    'dsn': oracledb.makedsn(host, port, sid=service_name)
                }
            }
        ]

        last_error = None
        for method in connection_methods:
            try:
                logger.info(f"尝试Oracle连接方式: {method['name']}")
                # Context managers close the cursor and connection even if a query fails.
                with oracledb.connect(**method['params']) as connection:
                    with connection.cursor() as cursor:
                        cursor.execute("SELECT 1 FROM DUAL")
                        result = cursor.fetchone()

                        # Version info is informational only; reading V$VERSION may
                        # need privileges the user lacks, so it must not fail the test.
                        version = None
                        try:
                            cursor.execute("SELECT BANNER FROM V$VERSION WHERE ROWNUM = 1")
                            version = cursor.fetchone()
                        except oracledb.DatabaseError as version_error:
                            logger.warning(f"获取数据库版本信息失败: {str(version_error)}")
            except Exception as e:
                last_error = e
                logger.warning(f"Oracle连接方式 '{method['name']}' 失败: {str(e)}")
                continue

            logger.info(f"Oracle连接成功 ({method['name']}): 查询结果={result}, 版本={version[0] if version else 'Unknown'}")
            return method['name']

        # Every connection method failed
        error_msg = f"所有Oracle连接方式都失败。最后一个错误: {str(last_error)}"
        logger.error(error_msg)
        raise Exception(error_msg)

    def get_engine(self, connection_id: str):
        """Return the registered engine for ``connection_id``.

        Raises:
            ValueError: If the id is not registered.
        """
        if connection_id not in self.engines:
            raise ValueError(f"连接ID不存在: {connection_id}")
        return self.engines[connection_id]

    def get_session(self, connection_id: str):
        """Return a new Session from the factory registered for ``connection_id``.

        Raises:
            ValueError: If the id is not registered.
        """
        if connection_id not in self.sessions:
            raise ValueError(f"连接ID不存在: {connection_id}")
        return self.sessions[connection_id]()

    def execute_query(self, connection_id: str, sql: str, params: Optional[Dict] = None) -> List[Dict]:
        """Run a SELECT-style statement and return every row as a column->value dict.

        ``params`` are bound parameters for ``text(sql)``; use them instead of string
        formatting to avoid SQL injection.
        """
        try:
            engine = self.get_engine(connection_id)
            with engine.connect() as conn:
                result = conn.execute(text(sql), params or {})
                columns = result.keys()
                rows = result.fetchall()
                return [dict(zip(columns, row)) for row in rows]
        except Exception as e:
            logger.error(f"执行查询失败: {str(e)}")
            raise Exception(f"查询执行失败: {str(e)}") from e

    def execute_non_query(self, connection_id: str, sql: str, params: Optional[Dict] = None) -> int:
        """Run an INSERT/UPDATE/DELETE inside a transaction and return the rowcount.

        The transaction commits when the ``conn.begin()`` block exits normally and
        rolls back if the statement raises.
        """
        try:
            engine = self.get_engine(connection_id)
            with engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(sql), params or {})
                    return result.rowcount
        except Exception as e:
            logger.error(f"执行非查询失败: {str(e)}")
            raise Exception(f"非查询执行失败: {str(e)}") from e

    def get_database_info(self, connection_id: str) -> Dict:
        """Return the current database name plus the list and count of its tables."""
        try:
            engine = self.get_engine(connection_id)
            inspector = inspect(engine)

            # Select the dialect-specific query by dialect name rather than by searching
            # the URL text, which also contains host/user/database names.
            db_name_sql = self._CURRENT_DB_SQL.get(engine.dialect.name, self._ORACLE_CURRENT_DB_SQL)
            with engine.connect() as conn:
                db_name = conn.execute(text(db_name_sql)).scalar()

            tables = inspector.get_table_names()

            return {
                "database_name": db_name,
                "tables": tables,
                "table_count": len(tables)
            }
        except Exception as e:
            logger.error(f"获取数据库信息失败: {str(e)}")
            raise Exception(f"获取数据库信息失败: {str(e)}") from e

    def get_tables_with_comments(self, connection_id: str) -> List[Dict[str, Any]]:
        """Return ``[{"table_name", "table_comment"}, ...]`` for the current database/schema.

        Raises:
            Exception: Wrapping any error, including an unsupported dialect.
        """
        try:
            engine = self.get_engine(connection_id)
            dialect = engine.dialect.name
            with engine.connect() as conn:
                if dialect == 'mysql':
                    sql = """
                        SELECT TABLE_NAME AS table_name, TABLE_COMMENT AS table_comment
                        FROM INFORMATION_SCHEMA.TABLES
                        WHERE TABLE_SCHEMA = DATABASE()
                        ORDER BY TABLE_NAME
                    """
                    return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
                elif dialect == 'postgresql':
                    # Restrict to the current schema (falls back to 'public').
                    schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public'
                    sql = """
                        SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment
                        FROM pg_class c
                        JOIN pg_namespace n ON n.oid = c.relnamespace
                        WHERE n.nspname = :schema AND c.relkind IN ('r','p')
                        ORDER BY c.relname
                    """
                    return [dict(row) for row in conn.execute(text(sql), {"schema": schema}).mappings().all()]
                elif dialect == 'mssql':
                    sql = """
                        SELECT t.name AS table_name, CAST(ep.value AS NVARCHAR(MAX)) AS table_comment
                        FROM sys.tables t
                        LEFT JOIN sys.extended_properties ep
                            ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.name = 'MS_Description'
                        ORDER BY t.name
                    """
                    return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
                elif dialect == 'oracle':
                    sql = """
                        SELECT table_name AS table_name, comments AS table_comment
                        FROM user_tab_comments
                        ORDER BY table_name
                    """
                    return [dict(row) for row in conn.execute(text(sql)).mappings().all()]
                else:
                    raise ValueError("不支持的数据库类型")
        except Exception as e:
            logger.error(f"获取表备注信息失败: {str(e)}")
            raise Exception(f"获取表备注信息失败: {str(e)}") from e

    def get_table_info(self, connection_id: str, table_name: str) -> Dict:
        """Return columns, primary key, foreign keys and indexes of ``table_name`` via the Inspector.

        Serialization notes:
        - SQLAlchemy column type objects (``VARCHAR``, ``NUMBER``, ...) are not JSON
          serializable, so each column's ``type`` is converted to its string form
          (e.g. ``VARCHAR(255)``), falling back to the type's class name.
        - Other Inspector results (primary key, foreign keys, indexes) are returned as-is.
        """
        try:
            engine = self.get_engine(connection_id)
            inspector = inspect(engine)

            # Column info, with the type object converted to a string
            columns = []
            for col in inspector.get_columns(table_name):
                serializable = dict(col)
                col_type = serializable.get('type')
                if col_type is not None:
                    try:
                        serializable['type'] = str(col_type)
                    except Exception:
                        # Fallback: use the type's class name
                        serializable['type'] = type(col_type).__name__
                columns.append(serializable)

            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": inspector.get_pk_constraint(table_name),
                "foreign_keys": inspector.get_foreign_keys(table_name),
                "indexes": inspector.get_indexes(table_name)
            }
        except Exception as e:
            logger.error(f"获取表信息失败: {str(e)}")
            raise Exception(f"获取表信息失败: {str(e)}") from e

    def get_table_columns(self, connection_id: str, table_name: str) -> List[Dict[str, Any]]:
        """Return column name, type, nullability, default and comment for ``table_name``.

        The exact keys returned vary by dialect (e.g. only MySQL reports
        ``numeric_precision``/``numeric_scale``). For PostgreSQL, only the table in
        the current schema is considered, matching :meth:`get_tables_with_comments`.

        Raises:
            Exception: Wrapping any error, including an unsupported dialect.
        """
        try:
            engine = self.get_engine(connection_id)
            dialect = engine.dialect.name
            with engine.connect() as conn:
                if dialect == 'mysql':
                    sql = """
                        SELECT COLUMN_NAME AS column_name,
                               DATA_TYPE AS data_type,
                               IS_NULLABLE AS is_nullable,
                               COLUMN_DEFAULT AS column_default,
                               COLUMN_COMMENT AS column_comment,
                               CHARACTER_MAXIMUM_LENGTH AS max_length,
                               NUMERIC_PRECISION AS numeric_precision,
                               NUMERIC_SCALE AS numeric_scale
                        FROM INFORMATION_SCHEMA.COLUMNS
                        WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table_name
                        ORDER BY ORDINAL_POSITION
                    """
                    return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
                elif dialect == 'postgresql':
                    # Without a schema filter, same-named tables in other schemas would
                    # have their columns merged into one result.
                    schema = conn.execute(text("SELECT current_schema()")).scalar() or 'public'
                    sql = """
                        SELECT a.attname AS column_name,
                               pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
                               NOT a.attnotnull AS is_nullable,
                               pg_get_expr(d.adbin, d.adrelid) AS column_default,
                               col_description(a.attrelid, a.attnum) AS column_comment
                        FROM pg_attribute a
                        JOIN pg_class c ON a.attrelid = c.oid
                        JOIN pg_namespace n ON n.oid = c.relnamespace
                        LEFT JOIN pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
                        WHERE c.relname = :table_name
                          AND n.nspname = :schema
                          AND a.attnum > 0
                          AND NOT a.attisdropped
                        ORDER BY a.attnum
                    """
                    bind = {"table_name": table_name, "schema": schema}
                    return [dict(row) for row in conn.execute(text(sql), bind).mappings().all()]
                elif dialect == 'mssql':
                    sql = """
                        SELECT c.name AS column_name,
                               t.name AS data_type,
                               CASE WHEN c.is_nullable = 1 THEN 'YES' ELSE 'NO' END AS is_nullable,
                               OBJECT_DEFINITION(c.default_object_id) AS column_default,
                               CAST(ep.value AS NVARCHAR(MAX)) AS column_comment,
                               COLUMNPROPERTY(c.object_id, c.name, 'charmaxlen') AS max_length
                        FROM sys.columns c
                        LEFT JOIN sys.types t ON c.user_type_id = t.user_type_id
                        LEFT JOIN sys.extended_properties ep
                            ON ep.major_id = c.object_id AND ep.minor_id = c.column_id AND ep.name = 'MS_Description'
                        WHERE c.object_id = OBJECT_ID(:table_name)
                        ORDER BY c.column_id
                    """
                    return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
                elif dialect == 'oracle':
                    sql = """
                        SELECT cols.column_name AS column_name,
                               cols.data_type AS data_type,
                               cols.nullable AS is_nullable,
                               cols.data_default AS column_default,
                               comm.comments AS column_comment,
                               cols.data_length AS max_length
                        FROM user_tab_columns cols
                        LEFT JOIN user_col_comments comm
                            ON comm.table_name = cols.table_name AND comm.column_name = cols.column_name
                        WHERE cols.table_name = UPPER(:table_name)
                        ORDER BY cols.column_id
                    """
                    return [dict(row) for row in conn.execute(text(sql), {"table_name": table_name}).mappings().all()]
                else:
                    raise ValueError("不支持的数据库类型")
        except Exception as e:
            logger.error(f"获取字段信息失败: {str(e)}")
            raise Exception(f"获取字段信息失败: {str(e)}") from e

    def close_connection(self, connection_id: str):
        """Dispose the engine for ``connection_id`` and unregister it (no-op if unknown).

        Errors are logged, not raised.
        """
        try:
            if connection_id in self.engines:
                self.engines[connection_id].dispose()
                self.engines.pop(connection_id)
                self.sessions.pop(connection_id, None)
                logger.info(f"已关闭数据库连接: {connection_id}")
        except Exception as e:
            logger.error(f"关闭连接失败: {str(e)}")

    def list_connections(self) -> List[str]:
        """Return the ids of all registered connections."""
        return list(self.engines.keys())


# Global database manager instance
db_manager = DatabaseManager()