diff --git a/main.py b/main.py index e69de29..b2a58dc 100644 --- a/main.py +++ b/main.py @@ -0,0 +1,29 @@ +import asyncio +import os +import logging + +from utils.db.db import DbConnector +from utils.db.db_mtr_network_nodes import MtrAdditionalNodeQuery + +# Change log level for each part of code +logging.getLogger('').setLevel(logging.DEBUG) +#logging.basicConfig(level=logging.DEBUG) + +DB_TYPE=os.environ.setdefault('DB_TYPE', 'mysql') +DB_NAME=os.getenv("DB_NAME") +DB_HOST=os.getenv("DB_HOST") +DB_PASS=os.getenv("DB_PASS") +DB_PORT=os.getenv("DB_PORT") +DB_USER=os.getenv("DB_USER") + +async def main(): + logging.debug("Connecting to LANDSLIDE MYSQL DATABASE") + db = DbConnector("spirent_mysql") + db.init(DB_TYPE, DB_HOST, int(DB_PORT), DB_NAME, DB_USER, DB_PASS) + # + mtr_additional_query = MtrAdditionalNodeQuery() + all_nodes = await mtr_additional_query.get_all_nodes() + [print(i.as_dict()) for i in all_nodes] + +if __name__ == "__main__": + loop = asyncio.run(main()) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/model_mtr_network_nodes.py b/models/model_mtr_network_nodes.py new file mode 100644 index 0000000..ea62606 --- /dev/null +++ b/models/model_mtr_network_nodes.py @@ -0,0 +1,40 @@ +from enum import Enum +import json + +from sqlalchemy import Integer, String, Column, Boolean, DateTime +from sqlalchemy.orm import declarative_base +from models.models_helpers import DeclarativeBaseToDict + + +class MtrAdditionalNodesModel(declarative_base(), DeclarativeBaseToDict): + __tablename__ = 'mtr_additional_nodes' + + id = Column(Integer, primary_key=True, index=True) + test_servers = Column(Boolean, default=False) + name = Column(String(50)) + hostname = Column(String(100)) + enabled = Column(Boolean) + description = Column(String(500)) + parameters = Column(String(200)) + + class Config: + orm_mode = True + +class MtrNodesReport(Enum): + STATUS_NO_CHANGE = "no" + STATUS_NEW = "new" + STATUS_CHANGE = "status or hostname change" + STATUS_NOT_CHECKED = "deleted" + + +class MtrNodesModel(declarative_base(), DeclarativeBaseToDict): + __tablename__ = 'mtr_nodes' + + id = Column(Integer, primary_key=True, index=True) + from_tas = Column(String(255)) + to_server_name = Column(String(255)) + to_server_hostname = Column(String(255)) + status = Column(String(100)) + last_change = Column(DateTime) + report = Column(String(255)) + parameters = Column(String(255)) diff --git a/models/models_helpers.py b/models/models_helpers.py new file mode 100644 index 0000000..37f2345 --- /dev/null +++ b/models/models_helpers.py @@ -0,0 +1,5 @@ + + +class DeclarativeBaseToDict: + def as_dict(self) -> dict: + return {i: self.__dict__[i] for i in self.__dict__ if i[0] != '_'} \ No newline at end of file diff --git a/requirements b/requirements new file mode 100644 index 0000000..12c2cf1 --- /dev/null +++ b/requirements @@ -0,0 +1,3 @@ +sqlalchemy[asyncio]>=2.0.41 +aiomysql>=0.2.0 + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/db/__init__.py b/utils/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/db/db.py b/utils/db/db.py new file mode 100644 index 0000000..3552933 --- /dev/null +++ b/utils/db/db.py @@ -0,0 +1,59 @@ + +import logging +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.pool import NullPool + +class DbConnector: + _instances = {} + + def __new__(cls, db_id="default"): + # print("==================") + # [print(f"{i} -> {id(cls._instances[i])}") for i in cls._instances] + # print("==================") + if db_id not in cls._instances: + cls._instances[db_id] = super().__new__(cls) + return cls._instances[db_id] + + def __init__(self, db_id:str="default"): + self.logger = logging.getLogger("spirent_service_db_connector") + + def init(self, db_type:str, host:str, port:int, db_name:str, user:str, password:str): + """ + + :param db_type: Example: mysql, postgresql, ... + :param host: + :param port: + :param db_name: + :param user: + :param password: + """ + + self.db_type = db_type + self.host = host + self.port = port + self.name = db_name + self.user = user + self.password = password + + protocol = { + 'mysql': 'mysql+aiomysql', + 'postgresql': 'postgresql+asyncpg', + } + + sqlalchemy_database_url = f"{protocol[self.db_type]}://{user}:{password}@{host}:{port}/{db_name}" + # for security reason, password not print in logs. + sqlalchemy_database_url_log = f"mysql://{user}:******@{host}:{port}/{db_name}" + self.logger.debug(f"Connecting to db {sqlalchemy_database_url_log}") + + self.engine = create_async_engine(sqlalchemy_database_url, poolclass=NullPool) + self.async_session_local = async_sessionmaker(bind=self.engine, + expire_on_commit=False, + class_=AsyncSession, + autoflush=False, + ) + + def get_db(self): + return self.async_session_local + + + diff --git a/utils/db/db_mtr_network_nodes.py b/utils/db/db_mtr_network_nodes.py new file mode 100644 index 0000000..494fe1b --- /dev/null +++ b/utils/db/db_mtr_network_nodes.py @@ -0,0 +1,44 @@ +from utils.db.db import DbConnector +from models.model_mtr_network_nodes import MtrAdditionalNodesModel + +from sqlalchemy import select + + +class MtrAdditionalNodeQuery: + def __init__(self): + db_con = DbConnector("spirent_mysql") + self.db = db_con.get_db() + + async def get_all_nodes(self) -> list: + """ + Retrieve all entries from the mtr_network_nodes database table. + """ + async with self.db() as session: + result = await session.execute(select(MtrAdditionalNodesModel)) + all_nodes = result.scalars().all() + #await session.close() + return all_nodes + + def get_all_nodes_for_testing(self) -> list: + """ + Retrieve all entries from the mtr_network_nodes database table, exclude rows used for storing parameters for + testing TAS servers. + """ + + all_nodes = (self.db.query(MtrAdditionalNodesModel) + .filter(MtrAdditionalNodesModel.test_servers == False) + .filter(MtrAdditionalNodesModel.enabled == True) + .all()) + + return all_nodes + + def get_mtr_config_for_servers(self) -> list: + all_nodes = (self.db.query(MtrAdditionalNodesModel) + .filter(MtrAdditionalNodesModel.test_servers == True) + .filter(MtrAdditionalNodesModel.enabled == True) + .all()) + + return all_nodes + + def get_node_by_id(self, node_id: int) -> MtrAdditionalNodesModel: + return self.db.query(MtrAdditionalNodesModel).filter_by(id=node_id).one()