Source code for py_yaml_fixtures.factories.sqlalchemy

from collections import defaultdict
from datetime import date, datetime, time, timedelta
from functools import lru_cache
from types import FunctionType
from typing import *

from sqlalchemy import orm as sa_orm
from sqlalchemy.ext.associationproxy import AssociationProxy

from ..types import Identifier
from .. import utils
from . import FactoryInterface


[docs]class SQLAlchemyModelFactory(FactoryInterface): """ Concrete factory for the SQLAlchemy ORM. """ def __init__(self, session: sa_orm.Session, models: Union[List[type], Dict[str, type]], date_factory: Optional[FunctionType] = None, datetime_factory: Optional[FunctionType] = None): """ :param session: the sqlalchemy session :param models: list of model classes, or dictionary of models by name :param date_factory: function used to generate dates (takes one parameter, the text value to convert) :param datetime_factory: function used to generate datetimes (takes one parameter, the text value to convert) """ super().__init__() self.session = session self.models = (models if isinstance(models, dict) else {model.__name__: model for model in models}) self.model_instances = defaultdict(dict) self.datetime_factory = datetime_factory or utils.datetime_factory self.date_factory = date_factory or utils.date_factory def create_or_update(self, identifier: Identifier, data: Dict[str, Any]): instance = self._get_existing(identifier, data) created = False if not instance: model_class = self.models[identifier.class_name] instance = model_class(**data) created = True else: for attr, value in data.items(): setattr(instance, attr, value) self.session.add(instance) self.model_instances[identifier.class_name][identifier.key] = instance return instance, created def _get_existing(self, identifier: Identifier, data: Dict[str, Any]): model_class = self.models[identifier.class_name] relationships = self.get_relationships(identifier.class_name) instance = self.model_instances[identifier.class_name].get(identifier.key) if isinstance(instance, model_class) and instance in self.session: return instance # try to filter by primary key or any unique columns filter_kwargs = {} for col in model_class.__mapper__.columns: if col.name in data and (col.primary_key or col.unique): filter_kwargs[col.name] = data[col.name] # otherwise fallback to filtering by values if not filter_kwargs: filter_kwargs = {k: v for k, v in data.items() if (k in relationships and hasattr(v, '__mapper__')) or v is None or isinstance(v, (bool, int, str, float))} if not filter_kwargs: return None filter_expressions = [] for k, v in filter_kwargs.items(): filter_expressions.append(getattr(model_class, k) == v) if k in relationships: for pk in v.__mapper__.primary_key: pk_value = getattr(v, pk.name) if pk_value is None: return None with self.session.no_autoflush: return self.session.query(model_class).filter(*filter_expressions).one_or_none() @lru_cache() def get_relationships(self, class_name: str) -> Set[str]: rv = set() model_class = self.models[class_name] for col_name, value in model_class.__mapper__.all_orm_descriptors.items(): # FIXME: this is apparently needed to make value.impl accessible? getattr(value, 'property', None) if (isinstance(value, AssociationProxy) or (getattr(value, 'impl', None) is not None and value.impl.uses_objects)): rv.add(col_name) return rv def maybe_convert_values(self, identifier: Identifier, data: Dict[str, Any], ) -> Dict[str, Any]: model_class = self.models[identifier.class_name] relationships = self.get_relationships(identifier.class_name) rv = data.copy() for col_name, value in data.items(): col = getattr(model_class, col_name) if col_name in relationships: rv[col_name] = self.loader.convert_identifiers(value) elif not hasattr(col, 'type'): continue elif col.type.python_type == date: rv[col_name] = self.date_factory(value) elif col.type.python_type == time: rv[col_name] = time(*[int(x) for x in value.split(':')]) elif col.type.python_type == datetime: rv[col_name] = self.datetime_factory(value) elif col.type.python_type == timedelta: duration, unit = value.split(" ") rv[col_name] = timedelta(**{unit: float(duration)}) return rv def commit(self): self.session.commit()