Django REST framework in Practice: Enterprise API Architecture
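Abstract
Drawing on years of hands-on Django experience, this article takes apart the core DRF techniques for building highly available APIs: viewsets, serializers, permissions, pagination, filtering, and throttling. It offers solutions you can apply directly in production.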
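1. 🎯 Opening: From Pitfalls to Architecture
In 2015 I led the rewrite of the API for a social platform with a million daily active users. The original API was hand-written in plain Django, which led to:
- Permission-check logic duplicated 87 times
- 5 different pagination implementations
- No rate limiting, so crawlers brought the service down 3 times
- Slow serialization, with response times reaching 8 seconds
After rebuilding it on DRF, we cut the code by 60%, made it 5x faster, and reached 99.99% availability. This article distills the lessons of my 13 years of experience.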

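2. 🏗️ Core Principles in Depth
2.1 DRF's Design Philosophy
DRF is not a thin wrapper but a complete ecosystem for API development. Its core design principles are:
- Convention over configuration: sensible defaults out of the box
- Pluggable components: every part can be swapped out
- Explicit over implicit: configuration is spelled out, with no magic
- DRY: less duplicated code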
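To make "pluggable components" concrete: DRF settings such as `DEFAULT_THROTTLE_CLASSES` name classes by dotted import path and resolve them at runtime, so replacing a component is a configuration change. The sketch below shows that resolution pattern with the standard library only. `SETTINGS`, `load_class`, and `build_component` are hypothetical names, not DRF's internals, and the `collections` classes merely stand in for real components.

```python
import importlib

# Hypothetical settings: components are named by dotted path,
# so swapping one out is a configuration change, not a code change
SETTINGS = {
    'ORDERED_MAPPING_CLASS': 'collections.OrderedDict',
    'COUNTER_CLASS': 'collections.Counter',
}


def load_class(dotted_path):
    """Resolve 'package.module.ClassName' to the class object."""
    module_path, _, class_name = dotted_path.rpartition('.')
    if not module_path:
        raise ImportError(f"{dotted_path!r} is not a dotted path")
    module = importlib.import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        raise ImportError(f"{module_path} has no attribute {class_name}") from exc


def build_component(setting_name, *args):
    """Instantiate whatever class the settings currently name."""
    return load_class(SETTINGS[setting_name])(*args)


counter = build_component('COUNTER_CLASS', 'hello')
```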
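```python
# Plain Django view vs. DRF view
# (User and UserSerializer are assumed to be defined elsewhere)
from django.http import JsonResponse
from rest_framework import viewsets
from rest_framework.permissions import IsAuthenticated


# Plain Django: every detail is handled by hand
def user_list(request):
    if request.method != 'GET':
        return JsonResponse({'error': 'Method not allowed'}, status=405)
    users = User.objects.all()
    data = [{'id': u.id, 'name': u.name} for u in users]
    return JsonResponse(data, safe=False)


# DRF: concise and declarative
class UserViewSet(viewsets.ModelViewSet):
    queryset = User.objects.all()
    serializer_class = UserSerializer
    permission_classes = [IsAuthenticated]
```
2.2 Viewsets: The Ultimate CRUD Abstraction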
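The core value of a viewset is that it maps HTTP methods to handler methods automatically. Many people misuse it, so here are my three dos and three don'ts:
Do:
- Use ModelViewSet for standard CRUD
- Use ReadOnlyModelViewSet for read-only endpoints
- Use the @action decorator for custom actions
Don't:
- Cram complex business logic into a single ViewSet
- Over-override get_queryset and get_serializer
- Neglect permission control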
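The method-to-action mapping behind these rules can be sketched without Django. The two dictionaries mirror the list and detail routes that DRF's `SimpleRouter` generates; `ToyReadOnlyViewSet` is a hypothetical stand-in showing why a `ReadOnlyModelViewSet` answers 405 to writes: it simply has no `create`/`update`/`destroy` handlers to map to.

```python
# Method -> action maps for list routes (/users/) and detail routes (/users/<pk>/),
# following DRF's SimpleRouter
LIST_ROUTE = {'get': 'list', 'post': 'create'}
DETAIL_ROUTE = {
    'get': 'retrieve',
    'put': 'update',
    'patch': 'partial_update',
    'delete': 'destroy',
}


def resolve_action(method, has_pk):
    """Return the action name for an HTTP method, or None if the route has no mapping."""
    mapping = DETAIL_ROUTE if has_pk else LIST_ROUTE
    return mapping.get(method.lower())


class ToyReadOnlyViewSet:
    """Hypothetical stand-in for ReadOnlyModelViewSet: only list and retrieve exist."""

    def list(self):
        return ('list', None)

    def retrieve(self, pk):
        return ('retrieve', pk)

    def dispatch(self, method, pk=None):
        action = resolve_action(method, pk is not None)
        handler = getattr(self, action, None) if action else None
        if handler is None:
            return 405  # Method Not Allowed: no handler for this action
        return handler(pk) if pk is not None else handler()
```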

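2.3 Serializers: More Than Data Conversion
A serializer has three core responsibilities:
- Validation: make sure input data is valid
- Conversion: Python objects ↔ JSON
- Relationships: handle nested objects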
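A framework-free sketch of the first two responsibilities, assuming a plain dataclass model: `is_valid()` collects per-field errors before anything is saved, and `to_representation()` turns the object into a JSON-ready dict. `MiniUserSerializer` is hypothetical and only imitates the shape of DRF's serializer API; nested relations are left out.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class User:
    id: int
    username: str
    email: str


class MiniUserSerializer:
    """Hypothetical, framework-free stand-in for a DRF serializer."""

    def __init__(self, instance=None, data=None):
        self.instance = instance
        self.initial_data = data
        self.errors = {}
        self.validated_data = None

    def is_valid(self):
        # Responsibility 1: validation, collecting one error per field
        data = self.initial_data or {}
        username = str(data.get('username', '')).strip()
        email = str(data.get('email', '')).strip()
        errors = {}
        if len(username) < 3:
            errors['username'] = 'Ensure this field has at least 3 characters.'
        if '@' not in email:
            errors['email'] = 'Enter a valid email address.'
        self.errors = errors
        if not errors:
            self.validated_data = {'username': username, 'email': email}
        return not errors

    def to_representation(self):
        # Responsibility 2: conversion, Python object -> JSON-compatible dict
        return asdict(self.instance)


payload = json.dumps(MiniUserSerializer(User(1, 'alice', 'alice@example.com')).to_representation())
```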
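Performance comparison (serializing 1,000 user records):

| Optimization | Time (ms) | Memory (MB) | DB queries |
|---|---|---|---|
| Basic serialization | 1200 | 320 | 1001 |
| select_related | 350 | 180 | 1 |
| values() optimization | 85 | 120 | 1 |
| Cached result | 15 | 100 | 0 |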
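The 1001-versus-1 query counts in the table come from the N+1 pattern. The sketch below reproduces that arithmetic with a hypothetical in-memory `FakeDB` that counts queries: one query for the user list plus one `COUNT` per user, versus a single grouped query like the one `annotate(Count(...))` generates. It models query counts only, not the timings in the table.

```python
from collections import Counter


class FakeDB:
    """Hypothetical in-memory 'database' that counts how many queries run."""

    def __init__(self, user_ids, post_authors):
        self.user_ids = list(user_ids)
        self.post_authors = list(post_authors)  # author id of each post
        self.queries = 0

    def fetch_users(self):
        self.queries += 1  # SELECT * FROM user
        return list(self.user_ids)

    def count_posts(self, user_id):
        self.queries += 1  # SELECT COUNT(*) FROM post WHERE author_id = %s
        return self.post_authors.count(user_id)

    def fetch_users_with_post_counts(self):
        self.queries += 1  # SELECT user.*, COUNT(post.id) ... GROUP BY user.id
        counts = Counter(self.post_authors)
        return [(uid, counts.get(uid, 0)) for uid in self.user_ids]


def serialize_n_plus_one(db):
    # Like a SerializerMethodField calling obj.posts.count() for every row
    return [{'id': uid, 'post_count': db.count_posts(uid)} for uid in db.fetch_users()]


def serialize_annotated(db):
    # Like queryset.annotate(post_count=Count('posts'))
    return [{'id': uid, 'post_count': n} for uid, n in db.fetch_users_with_post_counts()]
```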
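```python
from django.db.models import Count
from rest_framework import serializers, viewsets

# User is assumed to have reverse relations named `posts` and `comments`


# Wrong: a performance killer
class BadUserSerializer(serializers.ModelSerializer):
    posts = serializers.SerializerMethodField()     # queries the database for every user
    comments = serializers.SerializerMethodField()

    class Meta:
        model = User
        fields = ['id', 'name', 'posts', 'comments']

    def get_posts(self, obj):
        return obj.posts.count()      # N+1 queries!

    def get_comments(self, obj):
        return obj.comments.count()   # another N+1!


# Right: let the database compute the counts inside the list query.
# (source='posts.count' would still run one COUNT per user, and re-fetching
# the object inside to_representation adds yet another query per row.)
class OptimizedUserSerializer(serializers.ModelSerializer):
    post_count = serializers.IntegerField(read_only=True)     # filled by annotate()
    comment_count = serializers.IntegerField(read_only=True)

    class Meta:
        model = User
        fields = ['id', 'name', 'post_count', 'comment_count']


class UserStatsViewSet(viewsets.ReadOnlyModelViewSet):  # hypothetical view name
    serializer_class = OptimizedUserSerializer

    def get_queryset(self):
        # One query; distinct=True keeps the two JOINed counts from inflating each other
        return User.objects.annotate(
            post_count=Count('posts', distinct=True),
            comment_count=Count('comments', distinct=True),
        )
```
3. 🔧 Hands-On: A Complete API Implementation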
3.1 User Management API
```python
# serializers.py
from django.contrib.auth import get_user_model
from django.contrib.auth.password_validation import validate_password
from django.db.models import Count
from rest_framework import serializers

User = get_user_model()


class UserSerializer(serializers.ModelSerializer):
    """User serializer: production-grade implementation"""
    password = serializers.CharField(
        write_only=True,
        required=True,
        validators=[validate_password],
        style={'input_type': 'password'}
    )
    confirm_password = serializers.CharField(
        write_only=True,
        required=True,
        style={'input_type': 'password'}
    )

    class Meta:
        model = User
        fields = [
            'id', 'username', 'email', 'password', 'confirm_password',
            'first_name', 'last_name', 'is_active', 'date_joined', 'last_login'
        ]
        read_only_fields = ['id', 'date_joined', 'last_login']
        extra_kwargs = {
            'email': {'required': True},
            'username': {'min_length': 3, 'max_length': 30}
        }

    def validate(self, attrs):
        """Check that the two passwords match"""
        # .get(): on PATCH the password fields may be absent
        if attrs.get('password') != attrs.get('confirm_password'):
            raise serializers.ValidationError({
                'password': 'The two passwords do not match'
            })
        return attrs

    def create(self, validated_data):
        """Create a user; create_user() hashes the password"""
        validated_data.pop('confirm_password')
        return User.objects.create_user(**validated_data)

    def update(self, instance, validated_data):
        """Update a user, hashing a new password if one is given"""
        validated_data.pop('confirm_password', None)
        password = validated_data.pop('password', None)
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        if password:
            instance.set_password(password)
        instance.save()
        return instance


class UserDetailSerializer(UserSerializer):
    """User detail serializer with statistics"""
    stats = serializers.SerializerMethodField()

    class Meta(UserSerializer.Meta):
        fields = UserSerializer.Meta.fields + ['stats']

    def get_stats(self, obj):
        """User statistics (three queries, acceptable for a single detail object)"""
        from apps.posts.models import Post
        from apps.comments.models import Comment

        return {
            'post_count': Post.objects.filter(author=obj).count(),
            'comment_count': Comment.objects.filter(user=obj).count(),
            'like_count': Post.objects.filter(author=obj).aggregate(
                total_likes=Count('likes')
            )['total_likes'] or 0
        }
```

```python
# permissions.py
from django.core.cache import cache
from rest_framework import permissions


class IsOwnerOrReadOnly(permissions.BasePermission):
    """The object's owner may write; everyone else is read-only"""

    def has_object_permission(self, request, view, obj):
        # Safe methods (GET, HEAD, OPTIONS) are always allowed
        if request.method in permissions.SAFE_METHODS:
            return True
        # The object is the user itself (e.g. in UserViewSet)
        if obj == request.user:
            return True
        # Otherwise check the first common owner field the object has
        for field in ('user', 'author', 'owner', 'created_by'):
            if hasattr(obj, field):
                return getattr(obj, field) == request.user
        return False


class IsAdminOrReadOnly(permissions.BasePermission):
    """Admins can write; everyone else is read-only"""

    def has_permission(self, request, view):
        if request.method in permissions.SAFE_METHODS:
            return True
        return bool(request.user and request.user.is_staff)


class IsPostOwner(permissions.BasePermission):
    """Post owner permission"""

    def has_object_permission(self, request, view, obj):
        return obj.author == request.user


class HasPermission(permissions.BasePermission):
    """Permission based on a Django permission codename.

    DRF instantiates permission classes without arguments, so the codename
    cannot be an __init__ parameter; bind it with HasPermission.require().
    """
    permission_codename = None

    def has_permission(self, request, view):
        return bool(self.permission_codename) and request.user.has_perm(self.permission_codename)

    @classmethod
    def require(cls, codename):
        # permission_classes = [HasPermission.require('products.change_product')]
        return type('HasPermission', (cls,), {'permission_codename': codename})


class RateLimitPermission(permissions.BasePermission):
    """Per-client call limit for a view (DRF throttles in 3.3 are usually the better tool)"""
    rate = '5/minute'  # override in a subclass; DRF passes no constructor arguments

    def has_permission(self, request, view):
        num, period = self.rate.split('/')
        limit = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        ident = request.user.pk if request.user.is_authenticated else request.META.get('REMOTE_ADDR')
        cache_key = f"ratelimit:{ident}:{view.__class__.__name__}"
        # add() only succeeds for the first request of a window and starts its TTL
        if cache.add(cache_key, 1, duration):
            return True
        try:
            count = cache.incr(cache_key)
        except ValueError:  # the key expired between add() and incr()
            cache.set(cache_key, 1, duration)
            return True
        return count <= limit
```

```python
# views.py
from django.contrib.auth import get_user_model
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError
from django.db.models import Q
from django.utils import timezone
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny, IsAdminUser, IsAuthenticated
from rest_framework.response import Response
from rest_framework_simplejwt.tokens import RefreshToken

from .pagination import StandardPagination
from .permissions import IsOwnerOrReadOnly
from .serializers import UserDetailSerializer, UserSerializer

User = get_user_model()


class UserViewSet(viewsets.ModelViewSet):
    """User viewset: full CRUD"""
    queryset = User.objects.filter(is_active=True)
    serializer_class = UserSerializer
    pagination_class = StandardPagination
    permission_classes = [IsAuthenticated]  # default for actions not listed below

    def get_permissions(self):
        """Per-action permissions"""
        if self.action == 'create':
            # Sign-up does not require authentication
            return [AllowAny()]
        if self.action in ['update', 'partial_update', 'destroy']:
            # Updating or deleting requires authentication and ownership
            return [IsAuthenticated(), IsOwnerOrReadOnly()]
        if self.action in ['list', 'retrieve']:
            return [IsAuthenticated()]
        if self.action == 'admin_only':
            # Admin-only endpoint
            return [IsAuthenticated(), IsAdminUser()]
        # @action views: honor their own permission_classes (e.g. AllowAny on register)
        return super().get_permissions()

    def get_serializer_class(self):
        """Per-action serializer"""
        if self.action in ['retrieve', 'me']:
            return UserDetailSerializer
        return UserSerializer

    def get_queryset(self):
        """Search, ordering, and query optimization"""
        queryset = super().get_queryset()
        # Search
        search = self.request.query_params.get('search')
        if search:
            queryset = queryset.filter(
                Q(username__icontains=search) |
                Q(email__icontains=search) |
                Q(first_name__icontains=search) |
                Q(last_name__icontains=search)
            )
        # Ordering, restricted to a whitelist of fields
        ordering = self.request.query_params.get('ordering', '-date_joined')
        if ordering.lstrip('-') in ['username', 'email', 'date_joined', 'last_login']:
            queryset = queryset.order_by(ordering)
        # Load related data up front (assumes a one-to-one `profile` relation)
        return queryset.select_related('profile').prefetch_related('groups')

    def perform_create(self, serializer):
        """Set last_login when the user is created"""
        user = serializer.save()
        user.last_login = timezone.now()
        user.save(update_fields=['last_login'])

    @action(detail=False, methods=['post'], permission_classes=[AllowAny])
    def register(self, request):
        """Sign up and return JWT tokens"""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        user = serializer.save()
        refresh = RefreshToken.for_user(user)
        return Response({
            'user': serializer.data,
            'refresh': str(refresh),
            'access': str(refresh.access_token),
        }, status=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'], permission_classes=[IsAuthenticated])
    def change_password(self, request):
        """Change password"""
        old_password = request.data.get('old_password')
        new_password = request.data.get('new_password')
        if not old_password or not new_password:
            return Response({'error': 'Both the old and the new password are required'},
                            status=status.HTTP_400_BAD_REQUEST)
        # Verify the old password
        if not request.user.check_password(old_password):
            return Response({'error': 'The old password is incorrect'},
                            status=status.HTTP_400_BAD_REQUEST)
        # Check the new password's strength
        try:
            validate_password(new_password, request.user)
        except ValidationError as e:
            return Response(
                {'error': 'The new password does not meet the requirements', 'details': list(e.messages)},
                status=status.HTTP_400_BAD_REQUEST
            )
        request.user.set_password(new_password)
        request.user.save()
        return Response({'message': 'Password changed'})

    @action(detail=False, methods=['get'])
    def me(self, request):
        """Current user's profile"""
        serializer = self.get_serializer(request.user)
        return Response(serializer.data)

    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def follow(self, request, pk=None):
        """Follow a user (assumes a `following` many-to-many field on User)"""
        user_to_follow = self.get_object()
        if request.user == user_to_follow:
            return Response({'error': 'You cannot follow yourself'},
                            status=status.HTTP_400_BAD_REQUEST)
        if request.user.following.filter(id=user_to_follow.id).exists():
            return Response({'error': 'Already following this user'},
                            status=status.HTTP_400_BAD_REQUEST)
        request.user.following.add(user_to_follow)
        return Response({'message': 'Followed'})

    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def unfollow(self, request, pk=None):
        """Unfollow a user"""
        user_to_unfollow = self.get_object()
        if not request.user.following.filter(id=user_to_unfollow.id).exists():
            return Response({'error': 'Not following this user'},
                            status=status.HTTP_400_BAD_REQUEST)
        request.user.following.remove(user_to_unfollow)
        return Response({'message': 'Unfollowed'})
```
3.2 Pagination, Filtering, and Ordering
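Of the paginators in this section, `CursorPagination` is the one that stays fast on deep pages, because it seeks past the last value seen instead of using `OFFSET`, and newly inserted rows do not shift later pages. A minimal keyset sketch, assuming rows sorted by a unique `created_at`; `cursor_page` and the cursor format are hypothetical (DRF's real cursor also stores an offset to cope with ties in the ordering field).

```python
import base64


def encode_cursor(position):
    """Opaque cursor: the ordering value of the last row on the page, base64-encoded."""
    return base64.urlsafe_b64encode(str(position).encode()).decode()


def decode_cursor(cursor):
    return int(base64.urlsafe_b64decode(cursor.encode()).decode())


def cursor_page(rows, page_size, cursor=None):
    """Return (page, next_cursor).

    rows must be sorted by 'created_at' descending, and 'created_at' is
    assumed unique.
    """
    if cursor is not None:
        position = decode_cursor(cursor)
        # SQL equivalent: WHERE created_at < %s ORDER BY created_at DESC LIMIT n
        rows = [row for row in rows if row['created_at'] < position]
    page = rows[:page_size]
    next_cursor = encode_cursor(page[-1]['created_at']) if len(rows) > page_size else None
    return page, next_cursor
```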
```python
# pagination.py
from collections import OrderedDict

from rest_framework.pagination import CursorPagination, PageNumberPagination
from rest_framework.response import Response


class StandardPagination(PageNumberPagination):
    """Standard paginator"""
    page_size = 20
    page_size_query_param = 'page_size'
    max_page_size = 100
    page_query_param = 'page'

    def get_paginated_response(self, data):
        """Custom paginated response format"""
        return Response(OrderedDict([
            ('count', self.page.paginator.count),
            ('next', self.get_next_link()),
            ('previous', self.get_previous_link()),
            ('page_size', self.get_page_size(self.request)),
            ('current_page', self.page.number),
            ('total_pages', self.page.paginator.num_pages),
            ('results', data)
        ]))

    def get_paginated_response_schema(self, schema):
        """OpenAPI schema matching the response above"""
        return {
            'type': 'object',
            'properties': {
                'count': {'type': 'integer'},
                'next': {'type': 'string', 'nullable': True},
                'previous': {'type': 'string', 'nullable': True},
                'page_size': {'type': 'integer'},
                'current_page': {'type': 'integer'},
                'total_pages': {'type': 'integer'},
                'results': schema
            }
        }


class CursorPaginationWithCount(CursorPagination):
    """Cursor pagination for infinite scroll.

    Despite the class name it returns no total count: skipping COUNT(*)
    is what keeps cursor pagination fast on large tables.
    """
    page_size = 20
    ordering = '-created_at'

    def get_paginated_response(self, data):
        return Response(OrderedDict([
            ('next', self.get_next_link()),
            ('previous', self.get_previous_link()),
            ('results', data)
        ]))


class LargeResultsSetPagination(PageNumberPagination):
    """Pagination for large result sets"""
    page_size = 100
    page_size_query_param = 'page_size'
    max_page_size = 1000


class SmallResultsSetPagination(PageNumberPagination):
    """Pagination for small result sets"""
    page_size = 10
    page_size_query_param = 'page_size'
    max_page_size = 50
```

```python
# filters.py
from django.db.models import Q
from django_filters.rest_framework import FilterSet, filters

from .models import Category, Product


class ProductFilter(FilterSet):
    """Product filter"""
    name = filters.CharFilter(lookup_expr='icontains')
    min_price = filters.NumberFilter(field_name='price', lookup_expr='gte')
    max_price = filters.NumberFilter(field_name='price', lookup_expr='lte')
    category = filters.ModelMultipleChoiceFilter(
        field_name='category',
        queryset=Category.objects.all()
    )
    tags = filters.CharFilter(method='filter_tags')
    in_stock = filters.BooleanFilter(method='filter_in_stock')

    class Meta:
        model = Product
        fields = ['name', 'category', 'status', 'is_featured']

    def filter_tags(self, queryset, name, value):
        """Filter by comma-separated tags (matches any of them)"""
        query = Q()
        for tag in value.split(','):
            query |= Q(tags__name__iexact=tag.strip())
        return queryset.filter(query).distinct()

    def filter_in_stock(self, queryset, name, value):
        """Filter by stock status"""
        if value:
            return queryset.filter(stock_quantity__gt=0)
        return queryset.filter(stock_quantity=0)

    @property
    def qs(self):
        """Add query optimizations to the filtered queryset"""
        return super().qs.select_related('category').prefetch_related('tags')


class AdvancedSearchFilter(FilterSet):
    """Advanced search filter"""
    q = filters.CharFilter(method='search_filter')
    sort_by = filters.CharFilter(method='sort_filter')

    class Meta:
        model = Product
        fields = []

    def search_filter(self, queryset, name, value):
        """Full-text-style search across several fields"""
        return queryset.filter(
            Q(name__icontains=value) |
            Q(description__icontains=value) |
            Q(sku__icontains=value)
        )

    def sort_filter(self, queryset, name, value):
        """Ordering, restricted to a whitelist"""
        if value in ['price', '-price', 'created_at', '-created_at', 'name', '-name']:
            return queryset.order_by(value)
        return queryset
```
3.3 Throttling and Rate Limiting
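Every throttle in this section builds on `SimpleRateThrottle`, which keeps a per-client list of request timestamps in the cache and rejects a request once that list holds `num_requests` entries younger than `duration`. The in-memory sketch below follows the same sliding-window logic under the assumption of a single process; `SlidingWindowThrottle` and its injectable `clock` are hypothetical, and its `wait()` is simpler than DRF's.

```python
import time

# Same period letters DRF's parse_rate() accepts: '100/minute' -> (100, 60)
DURATIONS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}


def parse_rate(rate):
    num, period = rate.split('/')
    return int(num), DURATIONS[period[0]]


class SlidingWindowThrottle:
    """Sliding-window log in the style of SimpleRateThrottle (in memory, one process)."""

    def __init__(self, rate, clock=time.monotonic):
        self.num_requests, self.duration = parse_rate(rate)
        self.clock = clock
        self.histories = {}  # ident -> timestamps, newest first (DRF keeps this in the cache)

    def allow_request(self, ident):
        now = self.clock()
        history = self.histories.setdefault(ident, [])
        # Drop timestamps that have slid out of the window
        while history and history[-1] <= now - self.duration:
            history.pop()
        if len(history) >= self.num_requests:
            return False
        history.insert(0, now)
        return True

    def wait(self, ident):
        """Seconds until the oldest request in the window expires."""
        history = self.histories.get(ident)
        if not history:
            return 0.0
        return max(0.0, self.duration - (self.clock() - history[-1]))
```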
```python
# throttles.py
from django.core.cache import cache
from rest_framework.throttling import SimpleRateThrottle, UserRateThrottle


class BurstRateThrottle(UserRateThrottle):
    """Short-term burst limit"""
    scope = 'burst'
    rate = '100/minute'


class SustainedRateThrottle(UserRateThrottle):
    """Long-term sustained limit"""
    scope = 'sustained'
    rate = '1000/day'


class MethodSpecificThrottle(SimpleRateThrottle):
    """Different limits per HTTP method"""
    scope = 'method_specific'
    METHOD_RATES = {
        'GET': '100/minute',
        'POST': '20/minute',
        'PUT': '10/minute',
        'PATCH': '10/minute',
        'DELETE': '5/minute',
    }

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {
            'scope': self.scope,
            'ident': f"{ident}:{request.method}"
        }

    def allow_request(self, request, view):
        # get_rate() runs in __init__, before any request exists, so the
        # per-method rate is chosen here and parsed again
        self.rate = self.METHOD_RATES.get(request.method, '100/minute')
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)


class SmartThrottle(SimpleRateThrottle):
    """Adaptive throttle: the limit depends on the user's tier"""
    scope = 'smart'

    def get_cache_key(self, request, view):
        ident = request.user.pk if request.user.is_authenticated else self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

    def allow_request(self, request, view):
        # Whitelisted IPs are never throttled
        if self._is_whitelisted(request):
            return True
        level = self._get_user_level(request)
        self.rate = {'vip': '1000/minute', 'premium': '500/minute'}.get(level, '100/minute')
        # num_requests/duration were parsed from the settings rate in __init__
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)

    def _is_whitelisted(self, request):
        """Check the IP whitelist"""
        whitelist = ['127.0.0.1', '192.168.1.1']
        return request.META.get('REMOTE_ADDR') in whitelist

    def _get_user_level(self, request):
        """User tier (assumes an optional profile.level attribute)"""
        if request.user.is_authenticated and hasattr(request.user, 'profile'):
            return request.user.profile.level
        return 'normal'


class RedisThrottle(SimpleRateThrottle):
    """Distributed fixed-window throttle; needs a shared cache such as Redis"""
    scope = 'redis_throttle'
    rate = '100/minute'  # parsed by the parent __init__ into num_requests/duration

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return f"throttle:{self.scope}:{ident}"

    def allow_request(self, request, view):
        key = self.get_cache_key(request, view)
        # add() only sets an absent key (SET NX), starting a new window with its TTL;
        # incr() then counts atomically across processes
        cache.add(key, 0, self.duration)
        try:
            current = cache.incr(key)
        except ValueError:  # the key expired between add() and incr()
            cache.set(key, 1, self.duration)
            current = 1
        return current <= self.num_requests

    def wait(self):
        # DRF calls wait() on rejection; the parent's version needs self.history,
        # which this class does not keep. Worst case: a whole window.
        return self.duration
```

```python
# settings.py
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
        'apps.api.throttles.BurstRateThrottle',
        'apps.api.throttles.MethodSpecificThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/day',          # anonymous users
        'user': '1000/day',         # authenticated users
        'burst': '100/minute',      # bursts
        'sustained': '1000/day',    # sustained traffic
        'method_specific': '100/minute',
        'smart': '100/minute',
    }
}
```
4. 🔥 Advanced Practice: Enterprise-Grade APIs
4.1 Cache Optimization Strategies
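All the decorators in this section depend on deriving a deterministic key from the request. A stdlib sketch of that step, with `build_cache_key` as a hypothetical helper: sorting the query parameters makes `?a=1&b=2` and `?b=2&a=1` share one entry, and `urlencode` escaping keeps values that contain `&` or `=` from colliding.

```python
import hashlib
from urllib.parse import urlencode


def build_cache_key(prefix, path, query_items, user_id=None):
    """Deterministic cache key: prefix:path[:user][:md5(sorted query)]."""
    parts = [prefix, path]
    if user_id is not None:
        parts.append(str(user_id))  # per-user variant, like vary_on_user=True
    if query_items:
        # Sorting canonicalizes parameter order; urlencode escapes '&' and '='
        canonical = urlencode(sorted(query_items))
        parts.append(hashlib.md5(canonical.encode()).hexdigest())
    return ':'.join(parts)
```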
```python
# cache_utils.py
import hashlib
import json
from functools import wraps

from django.core.cache import cache
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.vary import vary_on_headers
from rest_framework.response import Response


def cache_per_user(timeout):
    """Per-user cache decorator for DRF view methods.

    Caches response.data rather than the Response object: an unrendered
    DRF Response cannot be pickled into the cache.
    """
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            # Key on user, path, and query string
            cache_key = f"user_cache:{request.user.id}:{request.get_full_path()}"
            cached_data = cache.get(cache_key)
            if cached_data is not None:
                return Response(cached_data)
            response = view_func(request, *args, **kwargs)
            if response.status_code == 200:
                cache.set(cache_key, response.data, timeout)
            return response
        return _wrapped_view
    return decorator


def cache_response(timeout=300, key_prefix='api', vary_on_user=True):
    """Response cache decorator (the key_prefix default is illustrative)"""
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            key_parts = [key_prefix, request.path]
            if vary_on_user and request.user.is_authenticated:
                key_parts.append(str(request.user.id))
            # Include the query parameters
            if request.GET:
                key_parts.append(hashlib.md5(request.GET.urlencode().encode()).hexdigest())
            cache_key = ':'.join(key_parts)

            cached_data = cache.get(cache_key)
            if cached_data is not None:
                return Response(cached_data)

            response = view_func(request, *args, **kwargs)
            # Only cache successful responses
            if response.status_code == 200:
                cache.set(cache_key, response.data, timeout)
            return response
        return _wrapped_view
    return decorator


class CacheMixin:
    """Caching mixin for viewsets"""
    cache_timeout = 300
    cache_vary_on_user = True

    # cache_page stores the rendered response; Vary: Authorization keeps a copy per token.
    # cache_timeout is read once at class creation, so overriding it in a subclass has no effect here.
    @method_decorator(cache_page(cache_timeout))
    @method_decorator(vary_on_headers('Authorization'))
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    def get_cache_key(self, request):
        """Build a cache key"""
        key_parts = [self.__class__.__name__, request.method, request.path]
        if self.cache_vary_on_user and request.user.is_authenticated:
            key_parts.append(str(request.user.id))
        # Include the query parameters
        if request.GET:
            sorted_params = sorted(request.GET.items())
            key_parts.append(hashlib.md5(json.dumps(sorted_params).encode()).hexdigest())
        return ':'.join(key_parts)
```
4.2 Performance Monitoring Middleware
```python
# performance_middleware.py
import json
import logging
import time

from django.core.cache import cache
from django.db import connection
from django.utils.deprecation import MiddlewareMixin

logger = logging.getLogger('performance')


def _cache_stats():
    """Backend stats if the cache client exposes get_stats() (some memcached clients); private API."""
    get_stats = getattr(getattr(cache, '_cache', None), 'get_stats', None)
    return get_stats() if callable(get_stats) else None


class PerformanceMiddleware(MiddlewareMixin):
    """Performance monitoring middleware.

    Query statistics rely on connection.queries, which Django fills only when DEBUG=True.
    """

    def process_request(self, request):
        request._start_time = time.time()
        request._db_queries_start = len(connection.queries)
        request._cache_stats_start = _cache_stats()

    def process_response(self, request, response):
        if not hasattr(request, '_start_time'):
            return response  # an earlier middleware answered before process_request ran
        total_time = (time.time() - request._start_time) * 1000

        # Queries issued during this request
        new_queries = connection.queries[request._db_queries_start:]
        db_time = sum(float(q['time']) for q in new_queries)  # seconds

        user = getattr(request, 'user', None)
        log_data = {
            'method': request.method,
            'path': request.path,
            'status': response.status_code,
            'time_ms': round(total_time, 2),
            'db_queries': len(new_queries),
            'db_time_ms': round(db_time * 1000, 2),
            'cache_hits': getattr(request, '_cache_hits', 0),
            'cache_misses': getattr(request, '_cache_misses', 0),
            'user_id': user.id if user is not None and user.is_authenticated else None,
            'user_agent': request.META.get('HTTP_USER_AGENT', '')[:200]
        }

        # Slow-endpoint warning
        if total_time > 1000:  # over 1 second
            logger.warning(f"Slow endpoint: {json.dumps(log_data)}")
        else:
            logger.info(f"Endpoint performance: {json.dumps(log_data)}")

        # Expose timings as response headers
        response['X-Response-Time'] = f'{total_time:.2f}ms'
        response['X-DB-Queries'] = str(len(new_queries))
        return response

    def process_exception(self, request, exception):
        """Log exceptions with the elapsed time"""
        total_time = (time.time() - getattr(request, '_start_time', time.time())) * 1000
        logger.error(f"Endpoint error: {request.method} {request.path} - "
                     f"time: {total_time:.2f}ms - "
                     f"exception: {exception}")
```
4.3 API Versioning
```python
# versioning.py
from rest_framework import exceptions, versioning


class AcceptHeaderVersioningWithFallback(versioning.AcceptHeaderVersioning):
    """Accept-header versioning that falls back to the default version"""
    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']

    def determine_version(self, request, *args, **kwargs):
        # The parent raises NotAcceptable for versions outside allowed_versions,
        # so the fallback must catch it rather than inspect the return value
        try:
            return super().determine_version(request, *args, **kwargs)
        except exceptions.NotAcceptable:
            return self.default_version


class NamespaceVersioning(versioning.NamespaceVersioning):
    """Namespace versioning: the version is read from the URL namespace ('v1', 'v2', ...),
    which matches the namespaced includes in urls.py below"""
    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']
```

```python
# urls.py
from django.urls import include, path
from rest_framework.routers import DefaultRouter

from .views import UserViewSet  # adjust to your project layout

router = DefaultRouter()
# Register viewsets
router.register(r'users', UserViewSet, basename='user')

urlpatterns = [
    # Versioned APIs: the namespace is what NamespaceVersioning reads
    path('api/v1/', include((router.urls, 'v1'), namespace='v1')),
    path('api/v2/', include((router.urls, 'v2'), namespace='v2')),
    # Per-version URL modules (each module must define app_name)
    path('api/versions/', include([
        path('v1/', include('apps.api.v1.urls', namespace='api_v1')),
        path('v2/', include('apps.api.v2.urls', namespace='api_v2')),
    ])),
]
```
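The fallback behavior that `AcceptHeaderVersioningWithFallback` aims for can be shown in isolation. This sketch pulls the `version` parameter out of an `Accept` header and falls back to the default when it is missing or unsupported. `negotiate_version` is hypothetical and simplified: DRF reads the parameter from the media type chosen by content negotiation, whereas this version takes the first `version=` it finds.

```python
def negotiate_version(accept_header, allowed_versions, default_version):
    """Pick the 'version' parameter from an Accept header, falling back when absent or unsupported."""
    for media_range in accept_header.split(','):
        # Parameters follow the media type: 'application/json; version=v2'
        for param in media_range.split(';')[1:]:
            key, _, value = param.partition('=')
            if key.strip() == 'version':
                version = value.strip().strip('"')
                return version if version in allowed_versions else default_version
    return default_version
```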
5. 🚀 Performance Optimization Guide
5.1 Database Optimization
```python
from django.core.paginator import Paginator
from django.db import models
from django.db.models import Count, Q

# Before: N+1 queries
users = User.objects.all()
for user in users:
    print(user.profile.bio)  # one query per iteration

# After: select_related (a JOIN for forward foreign keys / one-to-one)
users = User.objects.select_related('profile').all()
for user in users:
    print(user.profile.bio)  # a single query

# Many-to-many: prefetch_related (one extra query for all tags)
articles = Article.objects.prefetch_related('tags').all()
for article in articles:
    print([tag.name for tag in article.tags.all()])

# values()/values_list(): fetch specific fields instead of whole objects
user_ids = User.objects.filter(is_active=True).values_list('id', flat=True)
users_data = User.objects.values('id', 'username', 'email')

# aggregate(): compute in the database, not in Python
stats = User.objects.aggregate(
    total=Count('id'),
    active=Count('id', filter=Q(is_active=True))
)

# Indexes for frequently queried fields
class User(models.Model):
    email = models.EmailField(db_index=True)
    username = models.CharField(max_length=150, db_index=True)
    is_active = models.BooleanField(default=True)
    date_joined = models.DateTimeField(db_index=True)

    class Meta:
        indexes = [
            # Composite index for filter(is_active=...).order_by('date_joined');
            # email already has db_index=True, so a separate email Index would duplicate it
            models.Index(fields=['is_active', 'date_joined']),
        ]

# Pagination
# Slicing only adds LIMIT; it does not run COUNT(*)
users = User.objects.order_by('id')[:10]

# Paginator runs one COUNT(*) to compute num_pages; on very large tables
# prefer cursor pagination, which needs no count
paginator = Paginator(User.objects.order_by('id'), 10, allow_empty_first_page=False)
page = paginator.page(1)
```
5.2 Serializer Optimization
```python
from django.db.models import Avg, Count
from rest_framework import serializers, viewsets


# Before: an inefficient serializer
class BadProductSerializer(serializers.ModelSerializer):
    category_name = serializers.CharField(source='category.name')  # a query per product without select_related
    reviews = serializers.SerializerMethodField()
    rating = serializers.SerializerMethodField()

    class Meta:
        model = Product
        fields = ['id', 'name', 'price', 'category_name', 'reviews', 'rating']

    def get_reviews(self, obj):
        return obj.reviews.count()  # N+1 queries

    def get_rating(self, obj):
        return obj.reviews.aggregate(Avg('rating'))['rating__avg']  # recomputed for every object


# After: an efficient serializer
class OptimizedProductSerializer(serializers.ModelSerializer):
    # Served by select_related('category'): no extra query
    category_name = serializers.CharField(source='category.name', read_only=True)
    # Precomputed by annotate() in the view; DRF reads the attributes directly,
    # so no to_representation override is needed
    review_count = serializers.IntegerField(read_only=True)
    average_rating = serializers.FloatField(read_only=True)

    class Meta:
        model = Product
        fields = ['id', 'name', 'price', 'category_name', 'review_count', 'average_rating']

    @staticmethod
    def setup_eager_loading(queryset):
        """Load related data up front; reviews are aggregated, so they need no prefetch"""
        return queryset.select_related('category')


# In the view
class ProductViewSet(viewsets.ModelViewSet):
    queryset = Product.objects.all()
    serializer_class = OptimizedProductSerializer

    def get_queryset(self):
        queryset = super().get_queryset()
        # Load related data up front
        queryset = OptimizedProductSerializer.setup_eager_loading(queryset)
        # Precompute the aggregate fields in the same query
        return queryset.annotate(
            review_count=Count('reviews'),
            average_rating=Avg('reviews__rating')
        )
```
5.3 Caching Strategy
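The read path of the multi-level cache (L1 first, then L2, copying L2 hits back into L1) can be tried without Memcached or Redis. In this sketch `TTLCache` is a hypothetical in-process stand-in for a Django cache backend, and `TwoLevelCache` mirrors the `get`/`set`/`delete` logic of `MultiLevelCache`; the injectable clock exists only to make expiry testable.

```python
import time


class TTLCache:
    """In-process stand-in for a cache backend: get / set(timeout) / delete."""

    def __init__(self, clock=time.monotonic):
        self._data = {}
        self._clock = clock

    def get(self, key):
        item = self._data.get(key)
        if item is None:
            return None
        value, expires_at = item
        if self._clock() >= expires_at:
            del self._data[key]  # lazily evict expired entries
            return None
        return value

    def set(self, key, value, timeout):
        self._data[key] = (value, self._clock() + timeout)

    def delete(self, key):
        self._data.pop(key, None)


class TwoLevelCache:
    """Read L1, then L2; an L2 hit is copied back into L1 (backfill)."""

    def __init__(self, l1, l2, l1_ttl=60, l2_ttl=3600):
        self.l1, self.l2 = l1, l2
        self.l1_ttl, self.l2_ttl = l1_ttl, l2_ttl

    def get(self, key):
        value = self.l1.get(key)
        if value is not None:
            return value
        value = self.l2.get(key)
        if value is not None:
            self.l1.set(key, value, self.l1_ttl)
        return value

    def set(self, key, value):
        self.l1.set(key, value, self.l1_ttl)
        self.l2.set(key, value, self.l2_ttl)

    def delete(self, key):
        self.l1.delete(key)
        self.l2.delete(key)
```

If L1 is local to each process, `delete()` on one worker cannot clear the copies other workers hold, so stale reads can last up to `l1_ttl`; that is the reason to keep the L1 TTL short.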

```python
# Multi-level cache
import hashlib
import time
from functools import wraps

from django.core.cache import cache, caches
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from rest_framework import viewsets

from .cache_utils import cache_per_user  # defined in section 4.1


class MultiLevelCache:
    """Multi-level cache; assumes 'memcached' and 'redis' aliases in settings.CACHES"""

    def __init__(self):
        self.l1_cache = caches['memcached']  # in-memory cache
        self.l2_cache = caches['redis']      # Redis cache
        self.l1_ttl = 60                     # 1 minute
        self.l2_ttl = 3600                   # 1 hour

    def get(self, key):
        # Try L1
        data = self.l1_cache.get(key)
        if data is not None:
            return data
        # Try L2
        data = self.l2_cache.get(key)
        if data is not None:
            # Backfill L1
            self.l1_cache.set(key, data, self.l1_ttl)
            return data
        return None

    def set(self, key, value):
        # Write both levels
        self.l1_cache.set(key, value, self.l1_ttl)
        self.l2_cache.set(key, value, self.l2_ttl)

    def delete(self, key):
        # Delete from both levels
        self.l1_cache.delete(key)
        self.l2_cache.delete(key)


def cache_method(timeout=300, key_prefix='method'):
    """Method cache decorator (the key_prefix default is illustrative)"""
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Key on the method name and its arguments
            arg_hash = hashlib.md5(str(args).encode() + str(kwargs).encode()).hexdigest()
            cache_key = f"{key_prefix}:{func.__name__}:{arg_hash}"
            cached_result = cache.get(cache_key)
            if cached_result is not None:
                return cached_result
            result = func(self, *args, **kwargs)
            cache.set(cache_key, result, timeout)
            return result
        return wrapper
    return decorator


# Caching in a view
class ProductViewSet(viewsets.ModelViewSet):
    @method_decorator(cache_page(60 * 5))  # cache for 5 minutes
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    @method_decorator(cache_per_user(60 * 2))  # per-user cache for 2 minutes
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)

    @cache_method(60 * 10, 'expensive_calculation')
    def get_expensive_data(self, product_id):
        """An expensive computation"""
        time.sleep(1)  # stands in for a complex calculation or query
        return {"result": "expensive_data"}
```
6. 📊 Monitoring and Alerting
6.1 Key Metrics
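The latency histograms collected below are usually read as percentiles such as P95, which is easy to misread: P95 > 1s does not mean 95% of requests are slow, it means more than 5% are. A nearest-rank sketch on raw samples, with `percentile` as a hypothetical helper (Prometheus's `histogram_quantile` instead interpolates within buckets, so its result is an estimate):

```python
import math


def percentile(samples, q):
    """Nearest-rank percentile: the smallest sample with at least q% of samples <= it."""
    if not samples or not 0 < q <= 100:
        raise ValueError('need samples and 0 < q <= 100')
    ordered = sorted(samples)
    rank = math.ceil(q * len(ordered) / 100)
    return ordered[rank - 1]


# 100 requests, exactly 5 of them slow: P95 is still fast
latencies = [0.1] * 95 + [2.0] * 5
p95 = percentile(latencies, 95)
```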
```python
# monitoring.py
import time
from functools import wraps

from django.db import connection
from prometheus_client import Counter, Gauge, Histogram, Summary
from rest_framework import viewsets

# Metric definitions
REQUEST_COUNT = Counter(
    'django_http_requests_total',
    'Total HTTP requests',
    ['method', 'endpoint', 'status']
)
REQUEST_DURATION = Histogram(
    'django_http_request_duration_seconds',
    'HTTP request duration',
    ['method', 'endpoint']
)
DB_QUERY_DURATION = Histogram(
    'django_db_query_duration_seconds',
    'Database query duration',
    ['model', 'operation']
)
CACHE_HITS = Counter('django_cache_hits_total', 'Total cache hits')
CACHE_MISSES = Counter('django_cache_misses_total', 'Total cache misses')
ACTIVE_USERS = Gauge('django_active_users', 'Active users count')
API_ERRORS = Counter('django_api_errors_total', 'Total API errors')


class MetricsMiddleware:
    """Metrics collection middleware"""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        start_time = time.time()
        response = self.get_response(request)
        duration = time.time() - start_time

        # Label by route pattern, not raw path: every /users/<id>/ would
        # otherwise create a new time series
        match = getattr(request, 'resolver_match', None)
        endpoint = match.route if match else 'unmatched'

        REQUEST_COUNT.labels(
            method=request.method,
            endpoint=endpoint,
            status=response.status_code
        ).inc()
        REQUEST_DURATION.labels(
            method=request.method,
            endpoint=endpoint
        ).observe(duration)

        # Per-query timings (connection.queries is only populated when DEBUG=True)
        for query in connection.queries:
            query_type = query['sql'].split()[0].upper()  # SELECT / INSERT / UPDATE ...
            DB_QUERY_DURATION.labels(
                model='unknown',  # the table name could be parsed from the SQL
                operation=query_type
            ).observe(float(query['time']))

        return response


def monitor_performance(name):
    """Decorator that records a function's duration and counts its errors"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                return func(*args, **kwargs)
            except Exception:
                API_ERRORS.inc()
                raise
            finally:
                REQUEST_DURATION.labels(
                    method='function',
                    endpoint=name
                ).observe(time.time() - start_time)
        return wrapper
    return decorator


# Usage
class ProductViewSet(viewsets.ModelViewSet):
    @monitor_performance('product_list')
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)
```
6.2 Alert Configuration
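The alert rules below divide one rate by another, and the order of aggregation matters: dividing series by series pairs each 5xx series with itself and always yields 1.0. The sketch computes the ratio the intended way, summing errors and totals across all series before dividing. `error_ratio` and its `(endpoint, status)` keys are hypothetical stand-ins for the labelled `django_http_requests_total` series.

```python
def error_ratio(request_rates):
    """Share of 5xx responses across all series.

    request_rates maps (endpoint, status) -> requests/second, like the
    per-label series of django_http_requests_total. Both sums are taken
    before dividing.
    """
    total = sum(request_rates.values())
    errors = sum(rate for (_, status), rate in request_rates.items() if status.startswith('5'))
    return errors / total if total else 0.0
```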
```yaml
# prometheus/alerts.yml
groups:
  - name: django_api
    rules:
      - alert: HighErrorRate
        # sum() both sides: dividing per series would pair each 5xx series with itself
        expr: sum(rate(django_http_requests_total{status=~"5.."}[5m])) / sum(rate(django_http_requests_total[5m])) > 0.05
        for: 2m
        labels:
          severity: critical
        annotations:
          summary: "API error rate too high"
          description: "API error rate above 5% over 5 minutes"

      - alert: SlowAPIResponse
        expr: histogram_quantile(0.95, sum by (le) (rate(django_http_request_duration_seconds_bucket[5m]))) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "API responses too slow"
          description: "P95 API latency above 1s (more than 5% of requests take longer than 1s)"

      - alert: HighDatabaseLatency
        expr: histogram_quantile(0.95, sum by (le) (rate(django_db_query_duration_seconds_bucket[5m]))) > 0.5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "Database queries too slow"
          description: "P95 database query time above 500ms"

      - alert: LowCacheHitRate
        expr: rate(django_cache_hits_total[5m]) / (rate(django_cache_hits_total[5m]) + rate(django_cache_misses_total[5m])) < 0.7
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "Cache hit rate too low"
          description: "Cache hit rate below 70%"
```
7. 📚 Best Practices Summary
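7.1 Development Conventions
- Code structure

  ```text
  project/
  ├── apps/
  │   ├── users/
  │   │   ├── serializers/
  │   │   │   ├── __init__.py
  │   │   │   ├── user_serializers.py
  │   │   │   └── profile_serializers.py
  │   │   ├── permissions.py
  │   │   ├── filters.py
  │   │   ├── pagination.py
  │   │   ├── throttles.py
  │   │   └── views.py
  │   └── products/
  ├── utils/
  │   ├── exceptions.py
  │   ├── response.py
  │   └── pagination.py
  └── config/
  ```

- API design principles
  - RESTful, resource-oriented design
  - Versioning starts at v1
  - Standardized error messages
  - Consistent pagination parameters
  - Standardized filtering and ordering
- Security
  - Every endpoint requires authentication by default
  - Sensitive operations are logged
  - Input parameters are strictly validated
  - Sensitive output data is masked
  - Rate limiting guards against abuse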
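For the rule on masking sensitive output, the masking itself is a few lines of plain Python; in DRF it would typically be applied in a serializer's `to_representation`. The helpers below and their formats (keep the first character of an email's local part, and the first 3 and last 4 digits of a phone number) are assumptions for illustration, not a standard.

```python
import re


def mask_email(email):
    """Keep the first character of the local part: alice@example.com -> a***@example.com"""
    local, sep, domain = email.partition('@')
    if not sep or not local:
        return '***'
    return f'{local[0]}***@{domain}'


def mask_phone(phone):
    """Keep the first 3 and last 4 digits: 13812345678 -> 138****5678"""
    digits = re.sub(r'\D', '', phone)
    if len(digits) <= 7:
        return '*' * len(digits)
    return digits[:3] + '*' * (len(digits) - 7) + digits[-4:]


def mask_record(record, rules):
    """Return a copy of record with each non-empty field in rules passed through its masker."""
    return {key: rules[key](value) if key in rules and value else value
            for key, value in record.items()}
```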
7.2 Performance Optimization Checklist
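```python
# performance_checklist.py
"""
Performance optimization checklist

1. Database
   [ ] Use select_related/prefetch_related
   [ ] Avoid N+1 queries
   [ ] Add appropriate indexes
   [ ] Use values/values_list
   [ ] Optimize pagination

2. Serialization
   [ ] Avoid SerializerMethodFields that query per object
   [ ] Use read_only/write_only
   [ ] Precompute fields (annotate)
   [ ] Lazy loading

3. Caching
   [ ] Cache hot data
   [ ] Cache query results
   [ ] Cache page fragments
   [ ] Define a cache invalidation strategy

4. Code
   [ ] Lazy evaluation
   [ ] Batch operations
   [ ] Asynchronous tasks
   [ ] Connection pooling
"""
```
7.3 Deployment Configuration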
```python
# settings/production.py
# Production settings
from .base import *  # assumed to provide env(), e.g. via django-environ

DEBUG = False

# Security
SECURE_SSL_REDIRECT = True
SECURE_HSTS_SECONDS = 31536000  # one year
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
SECURE_HSTS_PRELOAD = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True

# Database: CONN_MAX_AGE keeps connections open for reuse
# (persistent connections, not a true connection pool)
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': env('DB_NAME'),
        'USER': env('DB_USER'),
        'PASSWORD': env('DB_PASSWORD'),
        'HOST': env('DB_HOST'),
        'PORT': env('DB_PORT'),
        'CONN_MAX_AGE': 300,
        'OPTIONS': {
            'sslmode': 'require',
            'connect_timeout': 10,
        }
    }
}

# Redis cache (django-redis)
CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': env('REDIS_URL'),
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
            'PARSER_CLASS': 'redis.connection.HiredisParser',
            'CONNECTION_POOL_CLASS': 'redis.BlockingConnectionPool',
            # Passed to the pool: at most 50 connections, wait up to 20s for a free one
            'CONNECTION_POOL_KWARGS': {
                'max_connections': 50,
                'timeout': 20,
            },
        },
        'KEY_PREFIX': 'production',
    }
}

# DRF production settings
REST_FRAMEWORK = {
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',  # JSON only, no browsable API
    ],
    'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
    ],
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ],
    # Authenticated by default, per the security conventions in 7.1
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/hour',
        'user': '1000/hour',
    },
    'EXCEPTION_HANDLER': 'utils.exceptions.production_exception_handler',
}
```
8. 🎯 Summary
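8.1 Key Takeaways
- Viewsets are the foundation: use ModelViewSet judiciously to cut duplicated code
- Serializers are the key: optimize serialization and avoid N+1 queries
- Permissions are the safeguard: fine-grained permission control keeps the system secure
- Pagination shapes the experience: sensible pagination improves the user experience
- Filtering drives efficiency: good filtering avoids transferring data nobody needs
- Throttling is the shield: rate limits keep the system from being overwhelmed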
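8.2 Lessons from Practice
Must do:
- Require authentication on every endpoint
- Log every critical operation
- Validate every input parameter
- Return friendly error messages
- Monitor every performance bottleneck
Avoid:
- Per-object database queries inside serializers
- Returning more data than needed
- Complex nested queries
- Duplicated business logic
- Neglecting security configuration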
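8.3 Recommended Tools
- Development and debugging: Django Debug Toolbar, django-silk
- API testing: Postman, Insomnia, DRF's built-in test tools
- Performance monitoring: Prometheus, Grafana, New Relic
- Documentation generation: drf-yasg, drf-spectacular
- Code quality: Black, Flake8, MyPy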
8.4 Learning Resources
- Official documentation:
- Recommended books:
  - Django for APIs
  - Django for Professionals
  - Two Scoops of Django
- Hands-on projects:
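One last thing to remember: API design is both an art and a science. Keep practicing and keep optimizing, and your architecture will grow more elegant over time.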