feat: enhance DB class with queue mode for concurrent processing and improved locking mechanism

Author: 晓丰
Date: 2025-07-04 21:31:20 +08:00
parent 6329f3e39d
commit ee18409096

DB.py

@@ -9,6 +9,7 @@ from sqlalchemy import (
create_engine, MetaData, Table, Column,
BigInteger, Integer, String, Text, DateTime, tuple_
)
from queue import Queue, Empty
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.exc import OperationalError
from sqlalchemy import select
@@ -665,212 +666,225 @@ class DBVidcon:
return json.loads(item) if item is not None else None
class DBSA:
# ======= Tunable parameters =======
FLUSH_EVERY_ROWS = 100 # row-count threshold
FLUSH_INTERVAL = 30 # time threshold (s)
MAX_SQL_RETRY = 3 # per-statement deadlock retries
SQL_RETRY_BASE_SLEEP = 0.5 # retry back-off base (s)
FLUSH_RETRY = 3 # whole-flush retry rounds
DELAY_ON_FAIL = 10 # wait after a failed flush round (s)
DEADLOCK_ERRNO = 1213 # MySQL deadlock error code
LOCK_TIMEOUT = 3 # mutex acquire timeout (s)
# ==================================
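# With these defaults a buffered batch is flushed once it reaches 100 rows
# or 30 s of age, whichever comes first (checked in upsert_video, step ⑥).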
# ----- Buffers -----
_buf_op = []
_buf_vid = []
_buf_payload = []
_last_flush = time.time()
# ----- Concurrency control -----
_lock = threading.Lock()
_existing_op_keys = set()
_existing_vid_keys = set()
# ----- queue / background-thread mode -----
_queue_mode = False
_queue = Queue()
# ================== Redis fallback (stub) ==================
@staticmethod
def push_record_many(rows):
"""Placeholder: push failed rows back to Redis"""
logger.warning("[Redis fallback] cnt=%d", len(rows))
# ----------------------------------------------------
# Public entry point
# ----------------------------------------------------
@classmethod
def upsert_video(cls, data: dict):
"""
业务线程/进程调用此方法写入
如果启用了 queue 模式则把 data 丢队列即可
"""
if cls._queue_mode:
cls._queue.put(data)
return
# ---------- deep-copy data / apply defaults ----------
data = copy.deepcopy(data)
data.setdefault("a_id", 0)
data.setdefault("is_repeat", 3)
data.setdefault("keyword", "")
data["sort"] = data.get("index", 0)
now_ts = int(time.time())
op_index_key = (data["v_xid"] or "", data["keyword"] or "", now_ts)
vid_index_key = (data["v_xid"] or "", data["title"] or "")
# ---------- ① acquire the mutex ----------
if not cls._lock.acquire(timeout=cls.LOCK_TIMEOUT):
logger.error("⚠️ [upsert_video] timed out acquiring cls._lock after %ds", cls.LOCK_TIMEOUT)
return
try:
# ---------- ② de-duplicate ----------
if op_index_key in cls._existing_op_keys or vid_index_key in cls._existing_vid_keys:
return
# ---------- ③ build op_row ----------
op_row = dict(
v_id=data["v_id"],
v_xid=data["v_xid"],
a_id=data["a_id"],
level=data.get("level", 0),
name_title=data["title"][:255],
keyword=data["keyword"],
is_repeat=data["is_repeat"],
sort=data["sort"],
createtime=now_ts,
updatetime=now_ts,
operatetime=now_ts,
batch=data.get("batch", 0),
machine=data.get("machine_id", 0),
is_piracy=data.get("is_piracy", '3'),
ts_status=data.get("ts_status", 1),
rn=data.get("rn", ""),
)
# ---------- ④ build vid_row ----------
vid_row = dict(
v_id=data["v_id"],
v_xid=data["v_xid"],
title=data["title"],
link=data["link"],
edition="",
duration=str(data["duration"]) if data.get("duration") else '0',
public_time=data["create_time"],
cover_pic=data["cover_pic"],
sort=data["sort"],
u_xid=data["u_xid"],
u_id=data["u_id"],
u_pic=data["u_pic"],
u_name=data["u_name"],
status=1,
createtime=now_ts,
updatetime=now_ts,
operatetime=now_ts,
watch_number=data.get("view", 0),
follow_number=data.get("fans", 0),
video_number=data.get("videos", 0),
)
# keep only columns that actually exist on the video table
video_fields = {c.name for c in video.columns}
vid_row = {k: v for k, v in vid_row.items() if k in video_fields}
# ---------- ⑤ append to buffers ----------
cls._buf_op.append(op_row)
cls._buf_vid.append(vid_row)
cls._buf_payload.append(data)
cls._existing_op_keys.add(op_index_key)
cls._existing_vid_keys.add(vid_index_key)
finally:
cls._lock.release()
# ---------- ⑥ decide whether to flush ----------
if (len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS or
time.time() - cls._last_flush >= cls.FLUSH_INTERVAL):
logger.info("flush: row-count or time threshold reached, writing to DB")
cls.flush()
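# Producer-side usage sketch (hypothetical payload; only the keys this
# method reads are shown, real records carry the full crawler output):
#
#     DBSA.upsert_video({
#         "v_id": 1, "v_xid": "x9abc", "title": "demo", "link": "https://example.com/v/x9abc",
#         "duration": 61, "create_time": "2025-07-04 21:31:20", "cover_pic": "",
#         "u_xid": "u1", "u_id": 2, "u_pic": "", "u_name": "demo-author",
#         "level": 0, "batch": 1, "rn": "US", "index": 0,
#     })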
# ----------------------------------------------------
# Safe single-statement execution:
# deadlock retry loop + connection-pool logging
# ----------------------------------------------------
@classmethod
def _safe_execute(cls, statement, desc=""):
"""带死锁自旋重试"""
for attempt in range(cls.MAX_SQL_RETRY):
try:
logger.debug("[%s] 准备借连接", desc)
with _engine.begin() as conn:
logger.debug("[%s] 借连接成功", desc)
conn.execute(statement)
return
except Exception as e:
# e.orig only exists on SQLAlchemy wrapper exceptions, hence the guard
err_no = getattr(getattr(e, "orig", None), "args", [None])[0]
if err_no == cls.DEADLOCK_ERRNO and attempt < cls.MAX_SQL_RETRY - 1:
time.sleep(cls.SQL_RETRY_BASE_SLEEP * (attempt + 1))
logger.warning("[%s] deadlock, retry %d/%d", desc,
attempt + 1, cls.MAX_SQL_RETRY)
continue
logger.exception("[%s] SQL execution failed", desc)
raise
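# Back-off grows linearly with the attempt number: with
# SQL_RETRY_BASE_SLEEP = 0.5 the waits are 0.5 s then 1.0 s, so a statement
# that deadlocks on every attempt spends at most ~1.5 s before the final raise.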
# ----------------------------------------------------
# Outer flush: retries the whole batch
# ----------------------------------------------------
@classmethod
def flush(cls):
for round_no in range(1, cls.FLUSH_RETRY + 1):
try:
cls._flush_once()
return  # success, exit
except Exception as e:
logger.error("[flush] 第 %d 轮失败: %s", round_no, e)
logger.error("[flush] 第 %d 轮失败%s", round_no, e)
if round_no < cls.FLUSH_RETRY:
logger.info("[flush] waiting %ds before retry…", cls.DELAY_ON_FAIL)
time.sleep(cls.DELAY_ON_FAIL)
else:
logger.error("[flush] 连续 %d 轮失败,退回 Redis", cls.FLUSH_RETRY)
cls.push_record_many(cls._buf_payload)
cls._clear_buffers()  # clear buffers to avoid an endless retry loop
return
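# Worst case, a failing flush occupies the caller for FLUSH_RETRY rounds
# with DELAY_ON_FAIL between them (3 rounds, 2 × 10 s ≈ 20 s of waiting,
# plus per-statement retry time) before the batch is pushed back to Redis.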
# ----------------------------------------------------
# The actual DB write
# ----------------------------------------------------
@classmethod
def _flush_once(cls):
"""一次完整落库流程,任何异常让上层捕获"""
t0 = time.time()
# ---------- copy buffers, then clear them ----------
if not cls._lock.acquire(timeout=cls.LOCK_TIMEOUT):
raise RuntimeError("flush could not acquire cls._lock, possible deadlock")
try:
op_rows = cls._buf_op[:]
vid_rows = cls._buf_vid[:]
payloads = cls._buf_payload[:]
cls._clear_buffers()
finally:
cls._lock.release()
if not op_rows and not vid_rows:
return
# ---------- write video_author ----------
authors_map = {}
now_ts = int(time.time())
for d in payloads:
uxid = d.get("u_xid")
if not uxid:
continue
authors_map[uxid] = dict(
u_xid=uxid,
u_id=d.get("u_id", 0),
u_name=d.get("u_name"),
u_pic=d.get("u_pic"),
follow_number=d.get("fans", 0),
v_number=d.get("videos", 0),
pv_number=0,
b_number=0,
create_time=datetime.utcnow(),
update_time=now_ts,
)
if authors_map:
stmt_auth = mysql_insert(video_author).values(list(authors_map.values()))
ondup_auth = stmt_auth.on_duplicate_key_update(
u_name=stmt_auth.inserted.u_name,
u_pic=stmt_auth.inserted.u_pic,
follow_number=stmt_auth.inserted.follow_number,
v_number=stmt_auth.inserted.v_number,
update_time=stmt_auth.inserted.update_time,
)
cls._safe_execute(ondup_auth, desc="video_author")
# ---------- write video_op ----------
if op_rows:
stmt_op = mysql_insert(video_op).values(op_rows)
ondup_op = stmt_op.on_duplicate_key_update(
@@ -882,28 +896,145 @@ class DBSA:
cls._safe_execute(ondup_op, desc="video_op")
logger.info("落表:操作记录 %d", len(op_rows))
# ---------- write video ----------
if vid_rows:
stmt_vid = mysql_insert(video).values(vid_rows)
ondup_vid = stmt_vid.on_duplicate_key_update(
title=stmt_vid.inserted.title,
link=stmt_vid.inserted.link,
edition=stmt_vid.inserted.edition,
duration=stmt_vid.inserted.duration,
watch_number=stmt_vid.inserted.watch_number,
follow_number=stmt_vid.inserted.follow_number,
video_number=stmt_vid.inserted.video_number,
public_time=stmt_vid.inserted.public_time,
cover_pic=stmt_vid.inserted.cover_pic,
sort=stmt_vid.inserted.sort,
u_xid=stmt_vid.inserted.u_xid,
u_id=stmt_vid.inserted.u_id,
u_pic=stmt_vid.inserted.u_pic,
u_name=stmt_vid.inserted.u_name,
status=stmt_vid.inserted.status,
updatetime=stmt_vid.inserted.updatetime,
operatetime=stmt_vid.inserted.operatetime,
)
cls._safe_execute(ondup_vid, desc="video")
logger.info("落表:视频记录 %d", len(vid_rows))
logger.debug("[flush] 本轮耗时 %.3f s", time.time() - t0)
# ----------------------------------------------------
# clear all buffers
# ----------------------------------------------------
@classmethod
def _clear_buffers(cls):
cls._buf_op.clear()
cls._buf_vid.clear()
cls._buf_payload.clear()
cls._existing_op_keys.clear()
cls._existing_vid_keys.clear()
cls._last_flush = time.time()
# ----------------------------------------------------
# ④ optional: single background flusher thread
# ----------------------------------------------------
@classmethod
def start_single_flusher(cls):
"""
启动后台线程 生产线程只喂 queueflusher 串行写库彻底避免锁竞争
"""
cls._queue_mode = True
def _worker():
batch = []
while True:
try:
data = cls._queue.get(timeout=3)
batch.append(data)
# drain the queue
while True:
try:
batch.append(cls._queue.get_nowait())
except Empty:
break
except Empty:
pass  # queue momentarily empty
if not batch:
continue
# ---- re-buffer the batch (lock-free) ----
for d in batch:
cls._buffer_without_lock(d)
batch.clear()
if len(cls._buf_vid) >= cls.FLUSH_EVERY_ROWS:
cls.flush()
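# NOTE: only the row-count threshold triggers a flush here; the time
# threshold (FLUSH_INTERVAL) could also be checked on each loop pass.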
threading.Thread(target=_worker, daemon=True).start()
logger.info("后台 flusher 线程已启动(单线程写库模式)")
# ----------------------------------------------------
# write queued data into the buffers (no thread lock)
# ----------------------------------------------------
@classmethod
def _buffer_without_lock(cls, data):
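# Lock-free by contract: in queue mode the single flusher thread is the
# only writer of these buffers, so no other thread mutates them concurrently.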
data = copy.deepcopy(data)
data.setdefault("is_repeat", 3)
data.setdefault("keyword", "")
data["sort"] = data.get("index", 0)
now_ts = int(time.time())
op_key = (data["v_xid"] or "", data["keyword"] or "", now_ts)
vid_key = (data["v_xid"] or "", data["title"] or "")
if op_key in cls._existing_op_keys or vid_key in cls._existing_vid_keys:
return
# -- op_row, same shape as in upsert_video --
op_row = dict(
v_id=data["v_id"],
v_xid=data["v_xid"],
a_id=data.get("a_id", 0),
level=data.get("level", 0),
name_title=data["title"][:255],
keyword=data["keyword"],
is_repeat=data["is_repeat"],
sort=data["sort"],
createtime=now_ts,
updatetime=now_ts,
operatetime=now_ts,
batch=data.get("batch", 0),
machine=data.get("machine_id", 0),
is_piracy=data.get("is_piracy", '3'),
ts_status=data.get("ts_status", 1),
rn=data.get("rn", ""),
)
vid_row = dict(
v_id=data["v_id"],
v_xid=data["v_xid"],
title=data["title"],
link=data["link"],
edition="",
duration=str(data["duration"]) if data.get("duration") else '0',
public_time=data["create_time"],
cover_pic=data["cover_pic"],
sort=data["sort"],
u_xid=data["u_xid"],
u_id=data["u_id"],
u_pic=data["u_pic"],
u_name=data["u_name"],
status=1,
createtime=now_ts,
updatetime=now_ts,
operatetime=now_ts,
watch_number=data.get("view", 0),
follow_number=data.get("fans", 0),
video_number=data.get("videos", 0),
)
vid_row = {k: v for k, v in vid_row.items() if k in video.c}
cls._buf_op.append(op_row)
cls._buf_vid.append(vid_row)
cls._buf_payload.append(data)
cls._existing_op_keys.add(op_key)
cls._existing_vid_keys.add(vid_key)