51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
import logging

from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
|
|
|
|
|
|
# Module-level handles populated by mysql_init(). These are annotations only:
# the names are not bound until mysql_init() runs, so reading them earlier
# raises NameError.
# Engine behind all sessions (MySQL or PostgreSQL, per mysql_init's db_type).
mysql_engine: AsyncEngine
# Factory producing AsyncSession objects bound to mysql_engine.
mysql_SessionLocal: async_sessionmaker
|
|
|
|
# Supported ``db_type`` values mapped to SQLAlchemy async dialect+driver names.
_ASYNC_DRIVERS = {
    "mysql": "mysql+aiomysql",
    "postgresql": "postgresql+asyncpg",
}

_logger = logging.getLogger(__name__)


def mysql_init(
    db_type: str,
    host: str,
    port: int,
    db_name: str,
    user: str,
    password: str,
    *,
    echo: bool = False,
    pool_size: int = 5,
    max_overflow: int = 10,
) -> None:
    """Initialize the module-level async engine and session factory.

    Despite the ``mysql_`` prefix, ``db_type`` selects either MySQL
    (``aiomysql`` driver) or PostgreSQL (``asyncpg`` driver).

    Args:
        db_type: Backend key; one of ``"mysql"`` or ``"postgresql"``.
        host: Database server host.
        port: Database server port. A numeric string is also accepted and
            converted with ``int()``.
        db_name: Database name.
        user: Login user name.
        password: Login password. Characters such as ``@``, ``:`` or ``/``
            are escaped by ``URL.create`` instead of being spliced into a
            URL string.
        echo: Passed to ``create_async_engine``.
        pool_size: Passed to ``create_async_engine``.
        max_overflow: Passed to ``create_async_engine``.

    Raises:
        ValueError: If ``db_type`` is not a supported backend.

    Side effects:
        Rebinds the module globals ``mysql_engine`` and
        ``mysql_SessionLocal``. A previously created engine is replaced but
        not disposed.
    """
    global mysql_engine, mysql_SessionLocal

    # Validate before building anything so a typo fails with a clear message
    # instead of a bare KeyError.
    try:
        drivername = _ASYNC_DRIVERS[db_type]
    except KeyError:
        raise ValueError(
            f"Unsupported db_type {db_type!r}; "
            f"expected one of {sorted(_ASYNC_DRIVERS)}"
        ) from None

    # URL.create escapes credentials correctly; an f-string URL is misparsed
    # when the user name or password contains '@', ':' or '/'.
    url = URL.create(
        drivername=drivername,
        username=user,
        password=password,
        host=host,
        port=int(port),  # the old f-string URL also tolerated str ports
        database=db_name,
    )
    # Never emit the plaintext password; render_as_string masks it as '***'.
    _logger.info("SQLAlchemy: %s", url.render_as_string(hide_password=True))

    mysql_engine = create_async_engine(
        url,
        echo=echo,
        pool_size=pool_size,
        max_overflow=max_overflow,
    )
    _logger.debug("mysql_engine: %s", mysql_engine)

    mysql_SessionLocal = async_sessionmaker(
        bind=mysql_engine,
        expire_on_commit=False,
        class_=AsyncSession,
        autoflush=False,
    )
    _logger.debug("mysql_SessionLocal: %s", mysql_SessionLocal)
|
|
|
|
def get_mysql_engine() -> AsyncEngine:
    """Return the async engine created by ``mysql_init()``.

    Raises:
        RuntimeError: If ``mysql_init()`` has not been called yet. Until then
            the module-level ``mysql_engine`` is only annotated, not assigned,
            so a bare lookup would surface as a confusing ``NameError``.
    """
    try:
        return mysql_engine
    except NameError:
        raise RuntimeError(
            "mysql_engine is not initialized; call mysql_init() first"
        ) from None
|
|
|
|
def get_mysql_session_local() -> AsyncSession:
    """Create and return a new session from the configured session factory.

    Despite the name, this returns a fresh session instance (it *calls*
    ``mysql_SessionLocal``), not the factory itself. Nothing here closes the
    session; use it as an async context manager, as ``get_mysql_db`` does.

    Raises:
        RuntimeError: If ``mysql_init()`` has not been called yet, so the
            session factory does not exist.
    """
    try:
        factory = mysql_SessionLocal
    except NameError:
        raise RuntimeError(
            "mysql_SessionLocal is not initialized; call mysql_init() first"
        ) from None
    return factory()
|
|
|
|
async def get_mysql_db():
    """Yield one database session per dependency use (FastAPI ``Depends``).

    A session is obtained from ``get_mysql_session_local()`` and entered as an
    async context manager. It is then handed to the consumer, and its context
    is exited when the generator finishes.
    """
    session_ctx = get_mysql_session_local()
    async with session_ctx as db:
        yield db
|