From ee18409096a214864d46c757104889a4e23bfe14 Mon Sep 17 00:00:00 2001
From: Franklin-F
Date: Fri, 4 Jul 2025 21:31:20 +0800
Subject: [PATCH] feat: enhance DB class with queue mode for concurrent
 processing and improved locking mechanism

---
 DB.py | 453 +++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 292 insertions(+), 161 deletions(-)

diff --git a/DB.py b/DB.py
index c59a216..e58057a 100644
--- a/DB.py
+++ b/DB.py
@@ -9,6 +9,7 @@ from sqlalchemy import (
     create_engine, MetaData, Table, Column, BigInteger, Integer, String,
     Text, DateTime, tuple_
 )
+from queue import Queue, Empty
 from sqlalchemy.dialects.mysql import insert as mysql_insert
 from sqlalchemy.exc import OperationalError
 from sqlalchemy import select
@@ -665,245 +666,375 @@ class DBVidcon:
         return json.loads(item) if item is not None else None
 
 
 class DBSA:
-    # ---------- Tunable parameters; adjust as needed ----------
-    FLUSH_EVERY_ROWS = 100        # row-count threshold
-    FLUSH_INTERVAL = 30           # time threshold (seconds)
-    MAX_SQL_RETRY = 3             # deadlock retries per statement
-    SQL_RETRY_BASE_SLEEP = 0.5    # first retry waits 0.5 s, increasing each attempt
-    FLUSH_RETRY = 3               # attempts for a whole flush
-    DELAY_ON_FAIL = 10            # wait (s) after a failed flush
-    DEADLOCK_ERRNO = 1213         # MySQL deadlock error code
-    # ----------------------------------------------------------
+    # ======= Tunable parameters =======
+    FLUSH_EVERY_ROWS = 100        # row-count threshold
+    FLUSH_INTERVAL = 30           # time threshold (seconds)
+    MAX_SQL_RETRY = 3             # deadlock retries per statement
+    SQL_RETRY_BASE_SLEEP = 0.5    # back-off base between retries
+    FLUSH_RETRY = 3               # attempts for a whole flush
+    DELAY_ON_FAIL = 10            # wait (s) after a failed flush
+    DEADLOCK_ERRNO = 1213         # MySQL deadlock error code
+    LOCK_TIMEOUT = 3              # mutex acquire timeout (s)
+    # ==================================
 
-    _buf_op = []
-    _buf_vid = []
-    _buf_payload = []
-    _last_flush = time.time()
-    _lock = threading.Lock()
+    # ----- Buffers -----
+    _buf_op = []
+    _buf_vid = []
+    _buf_payload = []
+    _last_flush = time.time()
+
+    # ----- Concurrency control -----
+    _lock = threading.Lock()
     _existing_op_keys = set()
     _existing_vid_keys = set()
 
-    # ---------------- Public interface -----------------
+    # ----- queue / background-thread mode -----
+    _queue_mode = False
+    _queue = Queue()
+
+    # ================== Fallback: push back to Redis ==================
     @staticmethod
     def push_record_many(rows):
-        """Placeholder that pushes failed rows back to Redis."""
-        logger.info(f"[fallback to Redis] cnt={len(rows)}")
+        logger.warning("[fallback to Redis] cnt=%d", len(rows))
 
-    # ---------------- Write entry point -------------
+    # ----------------------------------------------------
+    # Public entry point
+    # ----------------------------------------------------
     @classmethod
-    def upsert_video(cls, data):
+    def upsert_video(cls, data: dict):
+        """
+        Worker threads/processes call this to write a record.
+        In queue mode the record is simply enqueued.
+        """
+        if cls._queue_mode:
+            cls._queue.put(data)
+            return
+
+        # ---------- deep copy / defaults ----------
         data = copy.deepcopy(data)
         data.setdefault("a_id", 0)
         data.setdefault("is_repeat", 3)
         data.setdefault("keyword", "")
         data["sort"] = data.get("index", 0)
-
         now_ts = int(time.time())
         op_index_key = (data["v_xid"] or "", data["keyword"] or "", now_ts)
         vid_index_key = (data["v_xid"] or "", data["title"] or "")
 
-        # ---------- buffer under the lock ----------
-        with cls._lock:
-            if op_index_key in cls._existing_op_keys:
-                logger.debug(f"skip duplicate op record: {op_index_key}")
-                return
-            if vid_index_key in cls._existing_vid_keys:
-                logger.debug(f"skip duplicate video record: {vid_index_key}")
+        # ---------- ① acquire the mutex ----------
+        if not cls._lock.acquire(timeout=cls.LOCK_TIMEOUT):
+            logger.error("⚠️ [upsert_video] timed out acquiring cls._lock after %ds", cls.LOCK_TIMEOUT)
+            return
+        try:
+            # ---------- ② de-duplicate ----------
+            if op_index_key in cls._existing_op_keys or vid_index_key in cls._existing_vid_keys:
                 return
 
-            # build op_row / vid_row (original logic unchanged)
-            op_row = {
-                "v_id": data["v_id"],
-                "v_xid": data["v_xid"],
-                "a_id": data["a_id"],
-                "level": data["level"],
-                "name_title": data["title"][:100],
-                "keyword": data["keyword"],
-                "is_repeat": data["is_repeat"],
-                "sort": data["sort"],
-                "createtime": now_ts,
-                "updatetime": now_ts,
-                "operatetime": now_ts,
-                "batch": data["batch"],
-                "machine": data.get("machine_id", 0),
-                "is_piracy": data.get("is_piracy", '3'),
-                "ts_status": data.get("ts_status", 1),
-                "rn": data.get("rn", ""),
-            }
+            # ---------- ③ build op_row ----------
+            op_row = dict(
+                v_id=data["v_id"],
+                v_xid=data["v_xid"],
+                a_id=data["a_id"],
+                level=data.get("level", 0),
+                name_title=data["title"][:255],
+                keyword=data["keyword"],
+                is_repeat=data["is_repeat"],
+                sort=data["sort"],
+                createtime=now_ts,
+                updatetime=now_ts,
+                operatetime=now_ts,
+                batch=data.get("batch", 0),
+                machine=data.get("machine_id", 0),
+                is_piracy=data.get("is_piracy", '3'),
+                ts_status=data.get("ts_status", 1),
+                rn=data.get("rn", ""),
+            )
 
-            vid_row = {
-                "v_id": data["v_id"],
-                "v_xid": data["v_xid"],
-                "title": data["title"],
-                "link": data["link"],
-                "edition": "",
-                "duration": str(data["duration"]) if data.get("duration") else '0',
-                "public_time": data["create_time"],
-                "cover_pic": data["cover_pic"],
-                "sort": data["sort"],
-                "u_xid": data["u_xid"],
-                "u_id": data["u_id"],
-                "u_pic": data["u_pic"],
-                "u_name": data["u_name"],
-                "status": 1,
-                "createtime": now_ts,
-                "updatetime": now_ts,
-                "operatetime": now_ts,
-                "watch_number": data.get("view", 0),
-                "follow_number": data.get("fans", 0),
-                "video_number": data.get("videos", 0),
-            }
+            # ---------- ④ build vid_row ----------
+            vid_row = dict(
+                v_id=data["v_id"],
+                v_xid=data["v_xid"],
+                title=data["title"],
+                link=data["link"],
+                edition="",
+                duration=str(data["duration"]) if data.get("duration") else '0',
+                public_time=data["create_time"],
+                cover_pic=data["cover_pic"],
+                sort=data["sort"],
+                u_xid=data["u_xid"],
+                u_id=data["u_id"],
+                u_pic=data["u_pic"],
+                u_name=data["u_name"],
+                status=1,
+                createtime=now_ts,
+                updatetime=now_ts,
+                operatetime=now_ts,
+                watch_number=data.get("view", 0),
+                follow_number=data.get("fans", 0),
+                video_number=data.get("videos", 0),
+            )
+            # keep only columns that exist on the video table
+            vid_row = {k: v for k, v in vid_row.items() if k in video.c}
 
-            # keep only legal video-table fields
-            video_fields = {c.name for c in video.columns}
-            vid_row = {k: v for k, v in vid_row.items() if k in video_fields}
-
-            # append to buffers
+            # ---------- ⑤ append to buffers ----------
             cls._buf_op.append(op_row)
             cls._buf_vid.append(vid_row)
             cls._buf_payload.append(data)
             cls._existing_op_keys.add(op_index_key)
             cls._existing_vid_keys.add(vid_index_key)
+        finally:
+            cls._lock.release()
 
-            # decide whether to flush
-            if (len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS
-                    or time.time() - cls._last_flush >= cls.FLUSH_INTERVAL):
-                logger.info("flush: row-count or time threshold reached, writing to DB")
-                cls.flush()
+        # ---------- ⑥ decide whether to flush ----------
+        if (len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS or
+                time.time() - cls._last_flush >= cls.FLUSH_INTERVAL):
+            logger.info("flush: row-count or time threshold reached, writing to DB")
+            cls.flush()
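
`upsert_video` now has two write paths: direct buffering under `cls._lock`, or fire-and-forget enqueueing once queue mode is on. A minimal usage sketch of both call patterns; it assumes DB.py exposes `DBSA` at module level and that its engine and tables are configured, and the record values are invented:

    from DB import DBSA

    record = {
        "v_id": 1, "v_xid": "x1abcd", "title": "demo title",
        "link": "https://example.com/v/x1abcd", "duration": 61,
        "create_time": "2025-07-04 21:00:00", "cover_pic": "",
        "u_xid": "u9", "u_id": 9, "u_pic": "", "u_name": "demo-user",
        "level": 0, "batch": 1, "index": 3, "keyword": "demo",
    }

    # Direct mode: the calling thread buffers under cls._lock and may flush.
    DBSA.upsert_video(record)

    # Queue mode: start the single background flusher first; after that,
    # upsert_video() is only a Queue.put() and never touches the lock.
    DBSA.start_single_flusher()
    DBSA.upsert_video(record)
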
data["v_xid"], - "a_id": data["a_id"], - "level": data["level"], - "name_title": data["title"][:100], - "keyword": data["keyword"], - "is_repeat": data["is_repeat"], - "sort": data["sort"], - "createtime": now_ts, - "updatetime": now_ts, - "operatetime": now_ts, - "batch": data["batch"], - "machine": data.get("machine_id", 0), - "is_piracy": data.get("is_piracy", '3'), - "ts_status": data.get("ts_status", 1), - "rn": data.get("rn", ""), - } + # ---------- ③ 构造 op_row ---------- + op_row = dict( + v_id=data["v_id"], + v_xid=data["v_xid"], + a_id=data["a_id"], + level=data.get("level", 0), + name_title=data["title"][:255], + keyword=data["keyword"], + is_repeat=data["is_repeat"], + sort=data["sort"], + createtime=now_ts, + updatetime=now_ts, + operatetime=now_ts, + batch=data.get("batch", 0), + machine=data.get("machine_id", 0), + is_piracy=data.get("is_piracy", '3'), + ts_status=data.get("ts_status", 1), + rn=data.get("rn", ""), + ) - vid_row = { - "v_id": data["v_id"], - "v_xid": data["v_xid"], - "title": data["title"], - "link": data["link"], - "edition": "", - "duration": str(data["duration"]) if data.get("duration") else '0', - "public_time": data["create_time"], - "cover_pic": data["cover_pic"], - "sort": data["sort"], - "u_xid": data["u_xid"], - "u_id": data["u_id"], - "u_pic": data["u_pic"], - "u_name": data["u_name"], - "status": 1, - "createtime": now_ts, - "updatetime": now_ts, - "operatetime": now_ts, - "watch_number": data.get("view", 0), - "follow_number": data.get("fans", 0), - "video_number": data.get("videos", 0), - } + # ---------- ④ 构造 vid_row ---------- + vid_row = dict( + v_id=data["v_id"], + v_xid=data["v_xid"], + title=data["title"], + link=data["link"], + edition="", + duration=str(data["duration"]) if data.get("duration") else '0', + public_time=data["create_time"], + cover_pic=data["cover_pic"], + sort=data["sort"], + u_xid=data["u_xid"], + u_id=data["u_id"], + u_pic=data["u_pic"], + u_name=data["u_name"], + status=1, + createtime=now_ts, + updatetime=now_ts, + operatetime=now_ts, + watch_number=data.get("view", 0), + follow_number=data.get("fans", 0), + video_number=data.get("videos", 0), + ) + # 只保留 video 表合法字段 + vid_row = {k: v for k, v in vid_row.items() if k in video.c} - # 只保留 video 表中合法字段 - video_fields = {c.name for c in video.columns} - vid_row = {k: v for k, v in vid_row.items() if k in video_fields} - - # 写入缓冲 + # ---------- ⑤ 入缓冲 ---------- cls._buf_op.append(op_row) cls._buf_vid.append(vid_row) cls._buf_payload.append(data) cls._existing_op_keys.add(op_index_key) cls._existing_vid_keys.add(vid_index_key) + finally: + cls._lock.release() - # 判断是否触发 flush - if (len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS - or time.time() - cls._last_flush >= cls.FLUSH_INTERVAL): - logger.info("落表:达到行数或超时阈值,开始落库") - cls.flush() + # ---------- ⑥ 判断是否触发 flush ---------- + if (len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS or + time.time() - cls._last_flush >= cls.FLUSH_INTERVAL): + logger.info("落表:达到行数或超时阈值,开始落库") + cls.flush() - # ------------- SQL 安全执行 --------------- + # ---------------------------------------------------- + # 单条 SQL 安全执行:死锁自旋 + 连接池日志 + # ---------------------------------------------------- @classmethod def _safe_execute(cls, statement, desc=""): - """带死锁自旋重试""" for attempt in range(cls.MAX_SQL_RETRY): try: + logger.debug("[%s] 准备借连接", desc) with _engine.begin() as conn: + logger.debug("[%s] 借连接成功", desc) conn.execute(statement) return except Exception as e: err_no = getattr(e.orig, "args", [None])[0] if err_no == cls.DEADLOCK_ERRNO and attempt < 
 
-    # ------------- outer flush (whole-batch retries) ---------------
+    # ----------------------------------------------------
+    # Outer flush: retry the whole batch
+    # ----------------------------------------------------
     @classmethod
     def flush(cls):
         for round_no in range(1, cls.FLUSH_RETRY + 1):
             try:
                 cls._flush_once()
-                return  # done on success
+                return
             except Exception as e:
                 logger.error("[flush] round %d failed: %s", round_no, e)
                 if round_no < cls.FLUSH_RETRY:
                     time.sleep(cls.DELAY_ON_FAIL)
                     logger.info("[flush] waited %ds, retrying…", cls.DELAY_ON_FAIL)
                 else:
                     logger.error("[flush] %d consecutive rounds failed, falling back to Redis", cls.FLUSH_RETRY)
                     cls.push_record_many(cls._buf_payload)
-                    # clear buffers to avoid an endless loop
-                    cls._buf_op.clear()
-                    cls._buf_vid.clear()
-                    cls._buf_payload.clear()
-                    cls._existing_op_keys.clear()
-                    cls._existing_vid_keys.clear()
-                    cls._last_flush = time.time()
+                    cls._clear_buffers()
                     return
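
The retry-then-fallback shape in isolation: a sketch only, where `flush_with_fallback`, `write_batch`, and `fallback` are hypothetical callables standing in for `flush`, `_flush_once`, and `push_record_many`:

    import time

    def flush_with_fallback(rows, write_batch, fallback, rounds=3, delay=10):
        """Try write_batch up to `rounds` times, then hand rows to `fallback`."""
        for round_no in range(1, rounds + 1):
            try:
                write_batch(rows)
                return True
            except Exception as exc:
                print(f"[flush] round {round_no} failed: {exc}")
                if round_no < rounds:
                    time.sleep(delay)
        fallback(rows)          # e.g. push the payloads back to Redis
        return False

    # tiny demo with stand-in callables
    def failing_write(rows):
        raise RuntimeError("db down")

    flush_with_fallback([1, 2, 3], failing_write,
                        lambda rows: print(f"fallback: {len(rows)} rows"),
                        delay=0)
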
 
-    # ------------- the actual DB write ---------------
+    # ----------------------------------------------------
+    # The actual DB write
+    # ----------------------------------------------------
     @classmethod
     def _flush_once(cls):
-        """One full write pass; any exception propagates to the caller."""
-        # --- copy buffers, then clear them ---
-        with cls._lock:
+        t0 = time.time()
+        # ---------- copy buffers, then clear them ----------
+        if not cls._lock.acquire(timeout=cls.LOCK_TIMEOUT):
+            raise RuntimeError("flush could not acquire cls._lock; possible deadlock")
+        try:
             op_rows = cls._buf_op[:]
             vid_rows = cls._buf_vid[:]
             payloads = cls._buf_payload[:]
-
-            cls._buf_op.clear()
-            cls._buf_vid.clear()
-            cls._buf_payload.clear()
-            cls._existing_op_keys.clear()
-            cls._existing_vid_keys.clear()
-            cls._last_flush = time.time()
+            cls._clear_buffers()
+        finally:
+            cls._lock.release()
 
         if not op_rows and not vid_rows:
             return
 
-        # --- write video_author -------------------------------------------
+        # ---------- write video_author ----------
         authors_map = {}
         now_ts = int(time.time())
-        for data in payloads:
-            u_xid = data.get("u_xid")
-            if not u_xid:
+        for d in payloads:
+            uxid = d.get("u_xid")
+            if not uxid:
                 continue
-            authors_map[u_xid] = {
-                "u_id": data.get("u_id"),
-                "u_xid": u_xid,
-                "u_name": data.get("u_name"),
-                "u_pic": data.get("u_pic"),
-                "follow_number": data.get("fans", 0),
-                "v_number": data.get("videos", 0),
-                "pv_number": 0,
-                "b_number": 0,
-                "create_time": datetime.utcnow(),
-                "update_time": now_ts
-            }
-
+            authors_map[uxid] = dict(
+                u_xid=uxid,
+                u_id=d.get("u_id", 0),
+                u_name=d.get("u_name"),
+                u_pic=d.get("u_pic"),
+                follow_number=d.get("fans", 0),
+                v_number=d.get("videos", 0),
+                pv_number=0,
+                b_number=0,
+                create_time=datetime.utcnow(),
+                update_time=now_ts,
+            )
         if authors_map:
-            stmt_author = mysql_insert(video_author).values(list(authors_map.values()))
-            upd_author = {
-                "u_name": stmt_author.inserted.u_name,
-                "u_pic": stmt_author.inserted.u_pic,
-                "follow_number": stmt_author.inserted.follow_number,
-                "v_number": stmt_author.inserted.v_number,
-                "pv_number": stmt_author.inserted.pv_number,
-                "b_number": stmt_author.inserted.b_number,
-                "update_time": stmt_author.inserted.update_time,
-            }
-            ondup_author = stmt_author.on_duplicate_key_update(**upd_author)
-            cls._safe_execute(ondup_author, desc="video_author")
+            stmt_auth = mysql_insert(video_author).values(list(authors_map.values()))
+            ondup_auth = stmt_auth.on_duplicate_key_update(
+                u_name=stmt_auth.inserted.u_name,
+                u_pic=stmt_auth.inserted.u_pic,
+                follow_number=stmt_auth.inserted.follow_number,
+                v_number=stmt_auth.inserted.v_number,
+                update_time=stmt_auth.inserted.update_time,
+            )
+            cls._safe_execute(ondup_auth, desc="video_author")
 
-        # --- write video_op ----------------------------------------------
+        # ---------- write video_op ----------
         if op_rows:
             stmt_op = mysql_insert(video_op).values(op_rows)
             ondup_op = stmt_op.on_duplicate_key_update(
                 updatetime=stmt_op.inserted.updatetime,
                 operatetime=stmt_op.inserted.operatetime,
                 ts_status=stmt_op.inserted.ts_status,
                 is_repeat=stmt_op.inserted.is_repeat,
             )
             cls._safe_execute(ondup_op, desc="video_op")
             logger.info("flush: %d op rows", len(op_rows))
 
-        # --- write video --------------------------------------------------
+        # ---------- write video ----------
         if vid_rows:
             stmt_vid = mysql_insert(video).values(vid_rows)
-            upd = {
-                "title": stmt_vid.inserted.title,
-                "link": stmt_vid.inserted.link,
-                "edition": stmt_vid.inserted.edition,
-                "duration": stmt_vid.inserted.duration,
-                "watch_number": stmt_vid.inserted.watch_number,
-                "follow_number": stmt_vid.inserted.follow_number,
-                "video_number": stmt_vid.inserted.video_number,
-                "public_time": stmt_vid.inserted.public_time,
-                "cover_pic": stmt_vid.inserted.cover_pic,
-                "sort": stmt_vid.inserted.sort,
-                "u_xid": stmt_vid.inserted.u_xid,
-                "u_id": stmt_vid.inserted.u_id,
-                "u_pic": stmt_vid.inserted.u_pic,
-                "u_name": stmt_vid.inserted.u_name,
-                "status": stmt_vid.inserted.status,
-                "updatetime": stmt_vid.inserted.updatetime,
-                "operatetime": stmt_vid.inserted.operatetime,
-            }
-            ondup_vid = stmt_vid.on_duplicate_key_update(**upd)
+            ondup_vid = stmt_vid.on_duplicate_key_update(
+                title=stmt_vid.inserted.title,
+                link=stmt_vid.inserted.link,
+                edition=stmt_vid.inserted.edition,
+                duration=stmt_vid.inserted.duration,
+                watch_number=stmt_vid.inserted.watch_number,
+                follow_number=stmt_vid.inserted.follow_number,
+                video_number=stmt_vid.inserted.video_number,
+                public_time=stmt_vid.inserted.public_time,
+                cover_pic=stmt_vid.inserted.cover_pic,
+                sort=stmt_vid.inserted.sort,
+                u_xid=stmt_vid.inserted.u_xid,
+                u_id=stmt_vid.inserted.u_id,
+                u_pic=stmt_vid.inserted.u_pic,
+                u_name=stmt_vid.inserted.u_name,
+                status=stmt_vid.inserted.status,
+                updatetime=stmt_vid.inserted.updatetime,
+                operatetime=stmt_vid.inserted.operatetime,
+            )
             cls._safe_execute(ondup_vid, desc="video")
-            logger.info("flush: %d video rows", len(vid_rows))
\ No newline at end of file
+            logger.info("flush: %d video rows", len(vid_rows))
+
+        logger.debug("[flush] round took %.3f s", time.time() - t0)
+
+    # ----------------------------------------------------
+    # Clear buffers
+    # ----------------------------------------------------
+    @classmethod
+    def _clear_buffers(cls):
+        cls._buf_op.clear()
+        cls._buf_vid.clear()
+        cls._buf_payload.clear()
+        cls._existing_op_keys.clear()
+        cls._existing_vid_keys.clear()
+        cls._last_flush = time.time()
+
+    # ----------------------------------------------------
+    # Optional: single background flusher thread
+    # ----------------------------------------------------
+    @classmethod
+    def start_single_flusher(cls):
+        """
+        Start the background thread: producers only feed the queue while the
+        flusher writes to the DB serially, eliminating lock contention.
+        """
+        cls._queue_mode = True
+
+        def _worker():
+            batch = []
+            while True:
+                try:
+                    data = cls._queue.get(timeout=3)
+                    batch.append(data)
+                    # drain whatever else is already queued
+                    while True:
+                        try:
+                            batch.append(cls._queue.get_nowait())
+                        except Empty:
+                            break
+                except Empty:
+                    pass  # queue momentarily empty
+
+                # ---- re-buffer the batch (single consumer, no lock) ----
+                if batch:
+                    for d in batch:
+                        cls._buffer_without_lock(d)
+                    batch.clear()
+
+                # flush on row count or age, so quiet periods still drain
+                if cls._buf_vid and (
+                        len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS or
+                        time.time() - cls._last_flush >= cls.FLUSH_INTERVAL):
+                    cls.flush()
+
+        threading.Thread(target=_worker, daemon=True).start()
+        logger.info("background flusher thread started (single-writer mode)")
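
The block-then-drain idiom the worker uses, reduced to a standalone sketch; `next_batch` is an illustrative name, not part of DB.py:

    from queue import Queue, Empty

    def next_batch(q: Queue, timeout: float = 3.0) -> list:
        """Block for one item, then drain whatever else is already queued."""
        batch = []
        try:
            batch.append(q.get(timeout=timeout))   # block for the first item
            while True:
                try:
                    batch.append(q.get_nowait())   # opportunistic drain
                except Empty:
                    break
        except Empty:
            pass                                   # timed out: return []
        return batch

    q = Queue()
    for i in range(5):
        q.put(i)
    print(next_batch(q, timeout=0.1))   # -> [0, 1, 2, 3, 4]
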
 
+    # ----------------------------------------------------
+    # Buffer a queued record (no thread lock needed)
+    # ----------------------------------------------------
+    @classmethod
+    def _buffer_without_lock(cls, data):
+        data = copy.deepcopy(data)
+        data.setdefault("is_repeat", 3)
+        data.setdefault("keyword", "")
+        data["sort"] = data.get("index", 0)
+        now_ts = int(time.time())
+
+        op_key = (data["v_xid"] or "", data["keyword"] or "", now_ts)
+        vid_key = (data["v_xid"] or "", data["title"] or "")
+        if op_key in cls._existing_op_keys or vid_key in cls._existing_vid_keys:
+            return
+
+        # -- op_row, same shape as in upsert_video --
+        op_row = dict(
+            v_id=data["v_id"],
+            v_xid=data["v_xid"],
+            a_id=data.get("a_id", 0),
+            level=data.get("level", 0),
+            name_title=data["title"][:255],
+            keyword=data["keyword"],
+            is_repeat=data["is_repeat"],
+            sort=data["sort"],
+            createtime=now_ts,
+            updatetime=now_ts,
+            operatetime=now_ts,
+            batch=data.get("batch", 0),
+            machine=data.get("machine_id", 0),
+            is_piracy=data.get("is_piracy", '3'),
+            ts_status=data.get("ts_status", 1),
+            rn=data.get("rn", ""),
+        )
+        vid_row = dict(
+            v_id=data["v_id"],
+            v_xid=data["v_xid"],
+            title=data["title"],
+            link=data["link"],
+            edition="",
+            duration=str(data["duration"]) if data.get("duration") else '0',
+            public_time=data["create_time"],
+            cover_pic=data["cover_pic"],
+            sort=data["sort"],
+            u_xid=data["u_xid"],
+            u_id=data["u_id"],
+            u_pic=data["u_pic"],
+            u_name=data["u_name"],
+            status=1,
+            createtime=now_ts,
+            updatetime=now_ts,
+            operatetime=now_ts,
+            watch_number=data.get("view", 0),
+            follow_number=data.get("fans", 0),
+            video_number=data.get("videos", 0),
+        )
+        # keep only columns that exist on the video table
+        vid_row = {k: v for k, v in vid_row.items() if k in video.c}
+
+        cls._buf_op.append(op_row)
+        cls._buf_vid.append(vid_row)
+        cls._buf_payload.append(data)
+        cls._existing_op_keys.add(op_key)
+        cls._existing_vid_keys.add(vid_key)
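
Every write in this patch funnels through MySQL's INSERT ... ON DUPLICATE KEY UPDATE via SQLAlchemy. A self-contained sketch with a hypothetical two-column table (DB.py's real `video`, `video_op`, and `video_author` tables have many more columns):

    from sqlalchemy import MetaData, Table, Column, String
    from sqlalchemy.dialects import mysql
    from sqlalchemy.dialects.mysql import insert as mysql_insert

    metadata = MetaData()
    demo = Table(
        "demo", metadata,
        Column("v_xid", String(64), primary_key=True),
        Column("title", String(255)),
    )

    stmt = mysql_insert(demo).values([{"v_xid": "x1", "title": "a"}])
    # stmt.inserted.<col> compiles to VALUES(<col>), i.e. the incoming row's value.
    stmt = stmt.on_duplicate_key_update(title=stmt.inserted.title)
    print(stmt.compile(dialect=mysql.dialect()))
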