diff --git a/accounts/api/auth.py b/accounts/api/auth.py index 61bf74f..a04d0a8 100644 --- a/accounts/api/auth.py +++ b/accounts/api/auth.py @@ -3,6 +3,8 @@ from django.contrib.auth import get_user_model from rest_framework_simplejwt.tokens import RefreshToken from django.db.models import Q +from invites.models import RegistrationCode + auth_router = Router(tags=["认证"]) User = get_user_model() @@ -13,19 +15,28 @@ def register( username: str = Form(...), password: str = Form(...), email: str = Form(...), - role: str = Form("user") # 可选:默认 user + code: str = Form(None) ): if User.objects.filter(username=username).exists(): return {"success": False, "message": "用户名已存在"} - if role != "user": - return {"success": False, "message": "不能注册管理员或分管理账号"} - - user = User(username=username, email=email, role=role) + user = User(username=username, email=email, role="user") user.set_password(password) user.save() + if code: + try: + reg = RegistrationCode.objects.get(code=code) + if not reg.is_available(): + return {"success": False, "message": "注册码已达使用上限"} + user.authorized_websites.set(reg.manager.managed_websites.all()) + reg.used_count += 1 + reg.save() + except RegistrationCode.DoesNotExist: + return {"success": False, "message": "注册码无效"} + refresh = RefreshToken.for_user(user) + return { "success": True, "message": "注册成功", @@ -40,7 +51,6 @@ def register( } } - @auth_router.post("/login") def login( request, diff --git a/accounts/api/authorize.py b/accounts/api/authorize.py index e5ac0e8..b61ec04 100644 --- a/accounts/api/authorize.py +++ b/accounts/api/authorize.py @@ -12,9 +12,7 @@ from utils.permissions import manager_required, login_required router = Router(tags=["授权管理"]) -# ========================= -# 模型(授权申请) -# ========================= + class WebsiteAccessRequest(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE) website = models.ForeignKey(Website, on_delete=models.CASCADE) @@ -26,20 +24,17 @@ class WebsiteAccessRequest(models.Model): reason = models.TextField(blank=True) created_at = models.DateTimeField(auto_now_add=True) -# ========================= -# 请求结构 -# ========================= + class AuthorizeIn(Schema): user_id: int = Field(..., description="被授权的用户ID") website_ids: List[int] = Field(..., description="要授权的网站ID列表") + class AccessRequestIn(Schema): website_id: int = Field(...) reason: Optional[str] = Field(None, description="申请原因") -# ========================= -# 授权接口(POST) -# ========================= + @router.post("/authorize", auth=jwt_auth) @manager_required def authorize_user(request, data: AuthorizeIn): @@ -64,9 +59,7 @@ def authorize_user(request, data: AuthorizeIn): "message": f"已授权 {target_user.username} 访问 {len(data.website_ids)} 个网站", } -# ========================= -# 用户发起申请(POST) -# ========================= + @router.post("/apply", auth=jwt_auth) @login_required def request_access(request, data: AccessRequestIn): @@ -81,9 +74,7 @@ def request_access(request, data: AccessRequestIn): return {"success": True, "message": "申请已提交,等待分管理审批"} -# ========================= -# 分管理查看待审批列表 -# ========================= + @router.get("/pending", auth=jwt_auth) @manager_required def list_pending_requests(request): @@ -106,9 +97,7 @@ def list_pending_requests(request): ] } -# ========================= -# 分管理审批接口 -# ========================= + @router.post("/approve", auth=jwt_auth) @manager_required def approve_request(request, request_id: int = Query(...), approve: bool = Query(True)): @@ -124,3 +113,17 @@ def approve_request(request, request_id: int = Query(...), approve: bool = Query r.user.authorized_websites.add(r.website) return {"success": True, "message": f"已{'通过' if approve else '拒绝'} {r.user.username} 的访问申请"} + + +@router.get("/my-sites", auth=jwt_auth) +@login_required +def list_my_authorized_websites(request): + user = request.user + sites = user.authorized_websites.all().values("id", "name", "db_alias") + return {"success": True, "websites": list(sites)} + + +@router.get("/public-sites") +def list_public_websites(request): + websites = Website.objects.all().values("id", "name", "db_alias", "description") + return {"success": True, "websites": list(websites)} diff --git a/accounts/apps.py b/accounts/apps.py index 3e3c765..c8f1887 100644 --- a/accounts/apps.py +++ b/accounts/apps.py @@ -4,3 +4,6 @@ from django.apps import AppConfig class AccountsConfig(AppConfig): default_auto_field = 'django.db.models.BigAutoField' name = 'accounts' + + def ready(self): + import accounts.signals \ No newline at end of file diff --git a/accounts/signals.py b/accounts/signals.py new file mode 100644 index 0000000..0716404 --- /dev/null +++ b/accounts/signals.py @@ -0,0 +1,19 @@ +from django.db.models.signals import post_save +from django.dispatch import receiver +from accounts.models import User +from invites.models import RegistrationCode +import uuid + + +@receiver(post_save, sender=User) +def create_registration_code_for_manager(sender, instance, created, **kwargs): + if instance.role == "manager": + if created or not RegistrationCode.objects.filter(manager=instance).exists(): + RegistrationCode.objects.create( + code=str(uuid.uuid4()).replace("-", "")[:12], + manager=instance, + description=f"{instance.username} 的默认邀请码", + usage_limit=10 + ) + elif instance.role == "user": + RegistrationCode.objects.filter(manager=instance).update(usage_limit=0) diff --git a/core/settings.py b/core/settings.py index 5e389fa..da8a872 100644 --- a/core/settings.py +++ b/core/settings.py @@ -50,6 +50,7 @@ INSTALLED_APPS = [ 'access_control', 'admin_panel', 'logs', + 'invites' ] MIDDLEWARE = [ diff --git a/invites/__init__.py b/invites/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/invites/admin.py b/invites/admin.py new file mode 100644 index 0000000..51d6052 --- /dev/null +++ b/invites/admin.py @@ -0,0 +1,10 @@ +from django.contrib import admin +from invites.models import RegistrationCode + + +@admin.register(RegistrationCode) +class RegistrationCodeAdmin(admin.ModelAdmin): + list_display = ("code", "manager", "usage_limit", "used_count", "created_at") + list_filter = ("manager",) + search_fields = ("code", "manager__username") + readonly_fields = ("used_count", "created_at") diff --git a/invites/apps.py b/invites/apps.py new file mode 100644 index 0000000..9340cc0 --- /dev/null +++ b/invites/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class InvitesConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "invites" diff --git a/invites/migrations/__init__.py b/invites/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/invites/models.py b/invites/models.py new file mode 100644 index 0000000..cf0d138 --- /dev/null +++ b/invites/models.py @@ -0,0 +1,21 @@ +from django.db import models +from accounts.models import User + +class RegistrationCode(models.Model): + code = models.CharField(max_length=32, unique=True, verbose_name="注册码") + manager = models.ForeignKey( + User, + on_delete=models.CASCADE, + limit_choices_to={"role": "manager"}, + verbose_name="对应分管理" + ) + description = models.CharField(max_length=100, blank=True, verbose_name="说明") + usage_limit = models.IntegerField(default=1, verbose_name="最多使用次数") + used_count = models.IntegerField(default=0, verbose_name="已使用次数") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间") + + def __str__(self): + return f"{self.code} ({self.used_count}/{self.usage_limit})" + + def is_available(self): + return self.used_count < self.usage_limit diff --git a/invites/tests.py b/invites/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/invites/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/invites/views.py b/invites/views.py new file mode 100644 index 0000000..91ea44a --- /dev/null +++ b/invites/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/utils/permissions.py b/utils/permissions.py index 84cba5c..ed2f6f5 100644 --- a/utils/permissions.py +++ b/utils/permissions.py @@ -1,36 +1,32 @@ from functools import wraps from ninja.errors import HttpError - -from functools import wraps -from ninja.errors import HttpError - - def login_required(func): @wraps(func) def wrapper(request, *args, **kwargs): - user = getattr(request, 'user', None) + user = getattr(request, 'auth', None) if not user or not user.is_authenticated: raise HttpError(401, "请先登录") + request.user = user return func(request, *args, **kwargs) return wrapper - def manager_required(func): @wraps(func) def wrapper(request, *args, **kwargs): - user = getattr(request, 'user', None) + user = getattr(request, 'auth', None) if not user or not user.is_authenticated or user.role not in ['admin', 'manager']: raise HttpError(403, "仅分管理或管理员可访问") + request.user = user return func(request, *args, **kwargs) return wrapper - def admin_required(func): @wraps(func) def wrapper(request, *args, **kwargs): - user = getattr(request, 'user', None) + user = getattr(request, 'auth', None) if not user or not user.is_authenticated or user.role != 'admin': raise HttpError(403, "仅管理员可访问") + request.user = user return func(request, *args, **kwargs) return wrapper