import json
import copy
import datetime
from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
from werkzeug.exceptions import NotFound
from sqlalchemy import Column, SmallInteger, TIMESTAMP
from contextlib import contextmanager


class SQLAlchemy(_SQLAlchemy):
    @contextmanager
    def auto_commit(self):
        try:
            yield
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise e


class Query(BaseQuery):

    def screening_time(self, mode, aims_time, after: bool = True):
        """筛选时间,此方法跟在 filter/filter_by 之后使用
        :param mode: orm模型
        :param aims_time: 目标时间
        :param after: 从目标时间之后
        :return:query
        """
        if after:
            return self.filter(mode.create_time > aims_time)
        else:
            return self.filter(mode.create_time < aims_time)

    def filter_by(self, **kwargs):
        if 'status' not in kwargs.keys():
            kwargs['status'] = 1
        return super(Query, self).filter_by(**kwargs)

    def get_or_404(self, ident):
        rv = self.get(ident)
        if not rv:
            raise NotFound()
        return rv

    def first_or_404(self):
        rv = self.first()
        if not rv:
            raise NotFound()
        return rv


db = SQLAlchemy(query_class=Query)


class Common(object):
    """orm通用操作"""

    create_time = Column(TIMESTAMP, default=datetime.datetime.now)
    status = Column(SmallInteger, default=0)

    _privacy_fields = {'status'}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def delete(self):
        """逻辑删除"""
        self.status = 1
        return self
    def menstop(self):
        '''停用'''
        self.visible=0
        return  self
    def menstart(self):
        '''启用'''
        self.visible=1
        return  self
    def start(self):
        '''启用'''
        self.stop=0
    def stop_true(self):
        """逻辑删除"""
        self.stop = 1
        return self

    def delete_true(self):
        """物理删除"""
        db.session.delete(self)
        return self

    def direct_flush_(self):
        """直接预提交"""
        self.direct_add_()
        self.flush_()
        return self

    def flush_(self):
        """预提交，等于提交到数据库内存，还未写入数据库文件"""
        db.session.flush()
        return self

    def direct_add_(self):
        """直接添加事务"""
        db.session.add(self)
        return self

    def direct_commit_(self):
        """直接提交"""
        self.direct_add_()
        db.session.commit()
        return self

    def direct_update_(self):
        """直接更新"""
        db.session.commit()
        return self

    def direct_delete_(self):
        """直接删除"""
        db.session.delete(self)
        db.session.commit()
        db.session.close()

    @staticmethod
    def static_commit_():
        """直接提交.目的是尽量少直接引入db对象,集成在模型内"""
        db.session.commit()

    @staticmethod
    def static_flush_():
        """直接预提交"""
        db.session.flush()

    def set_attrs(self, attrs_dict):
        """批量更新模型的字段数据
        :param attrs_dict: {field:value}
        :return: self
        """
        for key, value in attrs_dict.items():
            setattr(self, key, value)
        return self

    def to_dict_(self, fields: set = None, funcs: list = None) -> dict:
        """返回字典表数据
        :param funcs: 序列化后需要被调用的函数
        :param fields: 允许被序列化的字段
        :return: dict({'field_name': field_value})
        """
        result = dict()
        if fields is None:
            fields = set(name.name for name in self.__table__._columns)
        for column in fields:
            value = getattr(self, column)
            if isinstance(value, datetime.datetime):
                value = value.strftime('%Y-%m-%d %H:%M:%S')
            result[column] = value
        # 通过funcs 添加额外的数据内容
        if funcs:
            for func in funcs:
                func, args, kwargs = func
                getattr(self, func)(result, *args, **kwargs)
        return result

    def serialization(self, increase: set = None, remove: set = None, funcs: list = None) -> dict:
        """序列化指定字段
        :param funcs: 序列化后需要调用的函数名与参数,示例:('func_name', tuple(), dict())
        :param increase: 需要(出增加/显示)的序列化输的字段
        :param remove: 需要(去除/隐藏)的序列化输出的字段
        :return: dict({'field_name': field_value})
        """
        if increase is None:
            increase = set()
        if remove is None:
            remove = set()
        if funcs is None:
            funcs = list()

        fields = copy.copy(self._privacy_fields)  # 拷贝默认隐藏字段,不影响到全局模型的序列化输出
        all_field = set(name.name for name in self.__table__._columns)  # 获得模型所有字段
        fields = fields - increase  # 取消被隐藏的字段
        fields = fields | remove  # 追加需要被隐藏的字段

        all_field = all_field - fields  # 从模型原型所有的可序列化字段中 去除需要被隐藏的字段
        return self.to_dict_(fields=all_field, funcs=funcs)  # 开始序列化

    # @property
    # def check_create_time_today(self):
    #     """检查记录时间是否属于当天内"""
    #     create_time = self.create_time.strftime("%Y-%m-%d")
    #     now = datetime.datetime.now().strftime("%Y-%m-%d")
    #     return create_time == now

    def update_create_time(self, new_time: int = None):
        """更新数据create_time字段
        此方法的作用是,在记录重复数据时,系统只保留一条重复数据时.直接更新时间即可.
        默认更新到当前时间
        :param new_time:指定时间
        """
        self.create_time = datetime.datetime.now()
        self.direct_update_()
        return self

    def __str__(self):
        description = ', '.join([f'{column.name}={getattr(self, column.name)}' for column in self.__table__._columns])
        return f'<{description}>'

    def __repr__(self):
        """想要此特殊方法被模型继承,需要将Common继承顺序排在ORM基类之前"""
        return f'<class \'{self.__class__.__name__}\' id={self.id if self.id else None}>'


