Source code for py_yaml_fixtures.factories.django

from collections import defaultdict
from types import FunctionType
from typing import *

from django.db import models as db

from ..types import Identifier
from .. import utils
from . import FactoryInterface


class DjangoModelFactory(FactoryInterface):
    """
    Concrete factory for the Django ORM.

    Holds a registry of model classes keyed by name, a per-class cache of the
    instances it has already created or updated, and, for every registered
    model, the set of attribute names that are treated as relationships.
    """

    def __init__(self,
                 models: Union[List[type], Dict[str, type]],
                 date_factory: Optional[FunctionType] = None,
                 datetime_factory: Optional[FunctionType] = None):
        """
        :param models: Either a list of model classes (registered under each
                       class's ``__name__``) or a ready-made mapping of name
                       to model class, which is used as-is.
        :param date_factory: Callable applied to the raw values of date
                             fields. Falls back to ``utils.date_factory``.
        :param datetime_factory: Callable applied to the raw values of
                                 datetime fields. Falls back to
                                 ``utils.datetime_factory``.
        """
        super().__init__()
        self.models = (models if isinstance(models, dict)
                       else {model.__name__: model for model in models})

        # class name -> {identifier key -> model instance}
        self.model_instances = defaultdict(dict)
        self.datetime_factory = datetime_factory or utils.datetime_factory
        self.date_factory = date_factory or utils.date_factory

        # model ``__name__`` -> names of the attributes holding relationships
        self.relations = {model.__name__: set()
                          for model in self.models.values()}
        for model in self.models.values():
            # NOTE(review): ``self.relations`` only has entries for the
            # registered models, so a relation that resolves to a model
            # outside ``models`` raises KeyError below.
            for rel in (list(model._meta.fields)
                        + list(model._meta.related_objects)):
                if isinstance(rel, db.ManyToManyRel):
                    # Recorded under the model at the other end of the
                    # relation, using the name of the field behind it.
                    self.relations[rel.related_model.__name__].add(
                        rel.field.name)
                elif isinstance(rel, (db.ForeignObject, db.ManyToOneRel)):
                    self.relations[rel.model.__name__].add(rel.name)

    def create_or_update(self,
                         identifier: Identifier,
                         data: Dict[str, Any],
                         ) -> Tuple[Any, bool]:
        """
        Return the model instance for ``identifier``, creating or updating it
        from ``data`` unless this factory has already produced it.

        The keys of ``data`` are sorted into three groups:

        * many-to-many fields are set aside and their values added to the
          instance once it exists;
        * primary-key and unique fields become the lookup arguments of
          ``update_or_create``;
        * everything else, including keys that are not attributes of the
          model class, is passed as ``defaults``.

        When no lookup arguments were found, all remaining values are used
        as the lookup instead.

        :param identifier: Names the model class (``class_name``) and the
                           fixture key (``key``) of the instance.
        :param data: Mapping of attribute name to value.
        :return: ``(instance, created)``. ``created`` is ``False`` when the
                 instance came from this factory's cache.
        """
        # Single lookup, and an explicit ``None`` test so a cached instance
        # is honoured regardless of how the model defines truthiness.
        cached = self.model_instances[identifier.class_name].get(
            identifier.key)
        if cached is not None:
            return cached, False

        kwargs, defaults, m2m = {}, {}, {}
        model_class = self.models[identifier.class_name]
        for k, v in data.items():
            if not hasattr(model_class, k):
                defaults[k] = v
                continue

            field = model_class._meta.get_field(k)
            if isinstance(field, (db.ManyToManyField, db.ManyToManyRel)):
                # Can only be populated after the instance has been saved.
                m2m[k] = v
                continue

            if field.primary_key or field.unique:
                kwargs[k] = v
            else:
                defaults[k] = v

        if not kwargs:
            # Nothing uniquely identifies the row: match on every value.
            instance, created = model_class.objects.update_or_create(
                **defaults)
        else:
            instance, created = model_class.objects.update_or_create(
                **kwargs, defaults=defaults)

        for k, v in m2m.items():
            for obj in v:
                getattr(instance, k).add(obj)

        self.model_instances[identifier.class_name][identifier.key] = instance
        return instance, created

    def get_relationships(self, class_name: str) -> Set[str]:
        """
        Return the set of relationship attribute names recorded for
        ``class_name``.

        :raises KeyError: if no model was registered under ``class_name``.
        """
        return self.relations[class_name]

    def maybe_convert_values(self,
                             identifier: Identifier,
                             data: Dict[str, Any],
                             ) -> Dict[str, Any]:
        """
        Return a shallow copy of ``data`` with relationship and date/time
        values converted; ``data`` itself is left untouched.

        * Values of relationship attributes are passed through
          ``self.loader.convert_identifiers``.
        * Values of datetime fields are passed through ``datetime_factory``.
        * Values of date fields are passed through ``date_factory``.
        * All other values are kept as they are.

        :param identifier: Names the model class whose fields describe
                           ``data``.
        :param data: Mapping of attribute name to raw value.
        """
        model_class = self.models[identifier.class_name]
        relationships = self.get_relationships(identifier.class_name)
        rv = data.copy()
        for field_name, value in data.items():
            if field_name in relationships:
                # ``self.loader`` is not assigned in this class; presumably
                # it is provided via FactoryInterface -- TODO confirm.
                rv[field_name] = self.loader.convert_identifiers(value)
                continue

            field = model_class._meta.get_field(field_name)

            # DateTimeField subclasses DateField in Django, so it has to be
            # tested first; otherwise datetime values would be handed to the
            # date factory and this branch could never run.
            if isinstance(field, db.fields.DateTimeField):
                rv[field_name] = self.datetime_factory(value)
            elif isinstance(field, db.fields.DateField):
                rv[field_name] = self.date_factory(value)
        return rv