后端
2022年8月15日
16 分钟阅读
Koa.js 中间件开发最佳实践
Koa.js 中间件开发最佳实践
中间件是 Koa.js 应用的核心,良好的中间件设计可以提升应用的性能、可维护性和安全性。本文将从多个维度探讨 Koa 中间件的开发最佳实践,帮助你编写高质量的中间件代码。
中间件设计原则
单一职责原则
每个中间件应该只做一件事,并且做好:
// ❌ Wrong: a single middleware that takes on too many responsibilities
async function doEverything(ctx, next) {
  const { method, url, headers } = ctx;

  // Concern 1: tag the request with an ID
  ctx.state.requestId = uuid();

  // Concern 2: access logging
  console.log(`${method} ${url}`);

  // Concern 3: authentication check
  const hasAuthHeader = Boolean(headers.authorization);
  if (!hasAuthHeader) {
    ctx.throw(401);
  }

  // Concern 4: business logic (the response is produced before next() runs)
  ctx.body = await getData();

  await next();
}
// ✅ Right: give each concern its own middleware

/** Stores a freshly generated request ID in ctx.state.requestId. */
async function requestId(ctx, next) {
  const id = uuid();
  ctx.state.requestId = id;
  await next();
}

/** Prints "<METHOD> <URL>" for every incoming request. */
async function logger(ctx, next) {
  const line = `${ctx.method} ${ctx.url}`;
  console.log(line);
  await next();
}

/** Rejects requests that carry no Authorization header with 401. */
async function auth(ctx, next) {
  const authorization = ctx.headers.authorization;
  if (!authorization) {
    ctx.throw(401);
  }
  await next();
}
可复用性
设计通用、可配置的中间件:
// ❌ Hard-coded middleware: the protected path is baked into the code
async function apiAuth(ctx, next) {
  const isTargetRoute = ctx.path === "/api/users";
  if (isTargetRoute) {
    // Hard-coded authentication logic would live here
  }
  // Every request, matching or not, continues down the chain exactly once
  await next();
}
// ✅ Reusable middleware: behaviour is driven entirely by options

/**
 * Default token extractor: returns the credentials of an
 * `Authorization: Bearer <token>` header, or undefined when the header is
 * missing or uses another scheme (so e.g. "Basic ..." credentials are never
 * mistaken for a bearer token). The scheme name is matched case-insensitively.
 */
const BEARER_RE = /^Bearer\s+(\S+)\s*$/i;
function defaultGetToken(ctx) {
  const header = ctx.headers.authorization;
  if (typeof header !== "string") {
    return undefined;
  }
  const match = BEARER_RE.exec(header);
  return match ? match[1] : undefined;
}

/**
 * Builds an authentication middleware.
 *
 * @param {object} [options]
 * @param {boolean} [options.required=true] - reject requests without a valid token
 * @param {(ctx) => string|undefined} [options.getToken] - extracts the token
 * @param {(token: string) => Promise<*>} [options.verifyToken] - resolves to the
 *   user for a valid token and throws for an invalid one. The default fails
 *   closed (throws), so no token is accepted until a real verifier is supplied.
 * @returns {Function} Koa middleware that stores the verified user in ctx.state.user
 */
function createAuthMiddleware(options = {}) {
  const {
    required = true,
    getToken = defaultGetToken,
    verifyToken = async () => {
      throw new Error("verifyToken is not configured");
    },
  } = options;

  return async function auth(ctx, next) {
    const token = getToken(ctx);
    if (required && !token) {
      ctx.throw(401, "Authentication required");
    }
    if (token) {
      try {
        ctx.state.user = await verifyToken(token);
      } catch (err) {
        // Optional auth: an unverifiable token is treated as anonymous.
        if (required) {
          ctx.throw(401, "Invalid token");
        }
      }
    }
    await next();
  };
}
// Usage: configure an instance with a custom verifier and mount it
const authOptions = {
  required: true,
  verifyToken: customVerifyToken,
};
const auth = createAuthMiddleware(authOptions);
app.use(auth);
幂等性
中间件应该是幂等的:对同一个请求多次执行,结果应该相同。最常见的破坏因素是读写模块级的共享可变状态——这会让某个请求的处理结果取决于之前已经处理过多少其他请求:
// ❌ Not idempotent: every call mutates module-level (shared) state, so the
// value a request observes depends on how many requests came before it.
let requestCount = 0;
async function counter(ctx, next) {
  ctx.state.count = ++requestCount;
  await next();
}

// ✅ Idempotent: only per-request state is written, so each request's
// outcome is independent of every other request.
async function requestCounter(ctx, next) {
  const perRequestCount = 1; // every request starts its own counter
  ctx.state.localCount = perRequestCount;
  await next();
}
错误处理最佳实践
统一错误格式
// Custom error hierarchy: each class carries an HTTP status and a machine code
class AppError extends Error {
  /**
   * @param {string} message - human-readable description
   * @param {number} [statusCode=500] - HTTP status to respond with
   * @param {string|null} [code=null] - stable machine-readable error code
   */
  constructor(message, statusCode = 500, code = null) {
    super(message);
    Object.assign(this, { statusCode, code });
    // new.target is the concrete subclass being constructed
    this.name = new.target.name;
    Error.captureStackTrace(this, new.target);
  }
}

/** 400 error carrying a list of per-field validation problems. */
class ValidationError extends AppError {
  constructor(message, errors = []) {
    super(message, 400, "VALIDATION_ERROR");
    this.errors = errors;
  }
}

/** 404 error whose message names the missing resource. */
class NotFoundError extends AppError {
  constructor(resource = "Resource") {
    const message = `${resource} not found`;
    super(message, 404, "NOT_FOUND");
  }
}
// Error-handling middleware: register it first so it wraps all the others
async function errorHandler(ctx, next) {
  try {
    await next();
  } catch (err) {
    // Known application errors
    if (err instanceof AppError) {
      ctx.status = err.statusCode;
      ctx.body = {
        error: {
          code: err.code,
          message: err.message,
          ...(err.errors && { errors: err.errors }),
        },
      };
      // AppError defaults to 500: server-side failures must still be
      // reported, otherwise they are answered but never logged.
      if (err.statusCode >= 500) {
        ctx.app.emit("error", err, ctx);
      }
      return;
    }
    // Koa HTTP errors (from ctx.throw) that are marked safe to expose
    if (err.status && err.expose) {
      ctx.status = err.status;
      // Keep headers attached to the error (e.g. WWW-Authenticate, Allow)
      if (err.headers) {
        ctx.set(err.headers);
      }
      ctx.body = {
        error: {
          code: "HTTP_ERROR",
          message: err.message,
        },
      };
      return;
    }
    // Unknown errors: generic message; details only in development
    ctx.status = 500;
    ctx.body = {
      error: {
        code: "INTERNAL_ERROR",
        message: "An unexpected error occurred",
        ...(process.env.NODE_ENV === "development" && {
          message: err.message,
          stack: err.stack,
        }),
      },
    };
    // Report the error through the app-level "error" event
    ctx.app.emit("error", err, ctx);
  }
}
异步错误处理
// ❌ Wrong: an error thrown inside a timer callback escapes the async chain
async function unsafeMiddleware(ctx, next) {
  await next();
  // This callback runs on a later tick, outside any surrounding try/catch, so
  // the throw surfaces as an uncaught exception instead of a rejected promise.
  const fail = () => {
    throw new Error("Async error"); // never reaches the error handler
  };
  setTimeout(fail, 100);
}
// ✅ Right: adapt the callback-style work into a promise and await it, so a
// failure becomes a rejection that can be turned into a normal Koa error.
async function safeMiddleware(ctx, next) {
  await next();

  const runDeferredWork = () =>
    new Promise((resolve, reject) => {
      setTimeout(() => {
        try {
          // ...the asynchronous work goes here...
          resolve();
        } catch (err) {
          reject(err);
        }
      }, 100);
    });

  try {
    await runDeferredWork();
  } catch (err) {
    ctx.throw(500, err.message);
  }
}
错误分类处理
/**
 * Error handler that classifies errors by `err.name` into typed responses.
 * Anything unclassified falls through to a generic branch that hides
 * internal details of server errors and reports them via the app "error"
 * event, so unexpected failures are never silently swallowed.
 */
async function errorHandler(ctx, next) {
  try {
    await next();
  } catch (err) {
    if (err.name === "ValidationError") {
      ctx.status = 400;
      ctx.body = {
        error: {
          type: "validation",
          message: err.message,
          details: err.details,
        },
      };
    } else if (err.name === "UnauthorizedError") {
      ctx.status = 401;
      ctx.body = {
        error: {
          type: "authentication",
          message: "Authentication required",
        },
      };
    } else if (err.name === "ForbiddenError") {
      ctx.status = 403;
      ctx.body = {
        error: {
          type: "authorization",
          message: "Insufficient permissions",
        },
      };
    } else {
      const status = err.status || 500;
      // Errors created by ctx.throw carry `expose`; otherwise only client
      // errors (4xx) may show their message to the caller.
      const expose = err.expose ?? status < 500;
      ctx.status = status;
      ctx.body = {
        error: {
          type: "server",
          message: expose ? err.message : "Internal Server Error",
        },
      };
      if (status >= 500) {
        ctx.app.emit("error", err, ctx);
      }
    }
  }
}
性能优化策略
减少不必要的计算
// ❌ Recomputes on every request: the heavy query runs each time
async function expensiveMiddleware(ctx, next) {
  ctx.state.data = await expensiveDatabaseQuery();
  await next();
}
// ✅ Cache the query result per URL for CACHE_TTL milliseconds
const cache = new Map();
const CACHE_TTL = 60000; // 1 minute
// ctx.url includes the query string, so clients can mint unlimited distinct
// keys; cap the entry count so the cache cannot grow without bound.
const CACHE_MAX_ENTRIES = 1000;

async function cachedMiddleware(ctx, next) {
  const cacheKey = ctx.url;
  const cached = cache.get(cacheKey);
  if (cached && Date.now() - cached.timestamp < CACHE_TTL) {
    ctx.state.data = cached.data;
    await next();
    return;
  }

  const data = await expensiveDatabaseQuery();
  // Delete first so re-insertion moves the key to the end: Map iteration
  // order is insertion order, so the first key is always the oldest entry.
  cache.delete(cacheKey);
  if (cache.size >= CACHE_MAX_ENTRIES) {
    cache.delete(cache.keys().next().value);
  }
  cache.set(cacheKey, { data, timestamp: Date.now() });

  ctx.state.data = data;
  await next();
}
条件执行优化
// ❌ Always loads everything, whether or not the route needs it
async function inefficientMiddleware(ctx, next) {
  const { userId } = ctx.state;
  // All three lookups run on every request, one after another
  const user = await getUser(userId);
  const settings = await getSettings(userId);
  const permissions = await getPermissions(userId);
  Object.assign(ctx.state, { user, settings, permissions });
  await next();
}
// ✅ Load data on demand, only for the routes that need it
async function efficientMiddleware(ctx, next) {
  const requestPath = ctx.path;
  const needsUser = requestPath.startsWith("/api/protected");
  const needsPermissions = requestPath.startsWith("/api/admin");

  if (needsUser) {
    ctx.state.user = await getUser(ctx.state.userId);
  }
  if (needsPermissions) {
    ctx.state.permissions = await getPermissions(ctx.state.userId);
  }
  await next();
}
批量处理
// ❌ Serial: each independent lookup waits for the previous one to finish
async function slowMiddleware(ctx, next) {
  const { state } = ctx;
  state.user = await getUser(state.userId);
  state.settings = await getSettings(state.userId);
  state.permissions = await getPermissions(state.userId);
  await next();
}
// ✅ Parallel: start every independent lookup at once, then await them together
async function fastMiddleware(ctx, next) {
  const { userId } = ctx.state;
  const loaders = [getUser, getSettings, getPermissions];
  const [user, settings, permissions] = await Promise.all(
    loaders.map((load) => load(userId))
  );
  Object.assign(ctx.state, { user, settings, permissions });
  await next();
}
使用流处理大文件
const fs = require("fs");
const path = require("path");
/**
 * Streams a file from the "uploads" directory for GET requests that carry a
 * `?file=` query parameter; every other request is passed downstream.
 *
 * Throws 400 when `file` is not a single string value, and 403 when it
 * resolves outside the uploads directory (path traversal or absolute path).
 */
async function fileStreamMiddleware(ctx, next) {
  const requested = ctx.query.file;
  if (ctx.method !== "GET" || !requested) {
    await next();
    return;
  }
  // `?file=a&file=b` yields an array; accept exactly one string.
  if (typeof requested !== "string") {
    ctx.throw(400, "Invalid file parameter");
  }

  const uploadsDir = path.join(__dirname, "uploads");
  const filePath = path.resolve(uploadsDir, requested);
  // Block "../../etc/passwd" and absolute paths: the resolved path must stay
  // strictly inside the uploads directory.
  if (!filePath.startsWith(uploadsDir + path.sep)) {
    ctx.throw(403, "Forbidden");
  }

  // ✅ Stream the file rather than loading it into memory in one go
  ctx.type = path.extname(filePath);
  ctx.body = fs.createReadStream(filePath);
  // No next(): the response body has already been set.
}
安全性最佳实践
输入验证
const Joi = require("joi");
/**
 * Builds a middleware that validates one part of the request against a Joi
 * schema and replaces it with the validated (converted, stripped) value.
 *
 * @param {object} schema - Joi schema (anything exposing `validateAsync`)
 * @param {"body"|"query"|"params"} [source="body"] - which request part to check
 * @returns {Function} Koa middleware; responds 400 with per-field errors on failure
 */
function validate(schema, source = "body") {
  return async function validator(ctx, next) {
    const data =
      source === "body"
        ? ctx.request.body
        : source === "query"
          ? ctx.query
          : ctx.params;

    let value;
    try {
      value = await schema.validateAsync(data, {
        abortEarly: false, // report every failing field, not only the first
        stripUnknown: true, // drop fields the schema does not declare
      });
    } catch (err) {
      if (err.isJoi) {
        ctx.throw(400, {
          message: "Validation failed",
          errors: err.details.map((detail) => ({
            field: detail.path.join("."),
            message: detail.message,
          })),
        });
      }
      throw err;
    }

    // Hand the validated data to downstream middleware
    if (source === "body") {
      ctx.request.body = value;
    } else if (source === "query") {
      // NOTE: Koa's `ctx.query` setter re-serializes the object into the
      // querystring, so converted types (e.g. numbers) read back as strings.
      ctx.query = value;
    } else {
      ctx.params = value;
    }

    // Outside the try/catch on purpose: errors thrown further down the chain
    // must not be reported as a validation failure of this request.
    await next();
  };
}
// 使用
const userSchema = Joi.object({
name: Joi.string().required().max(100),
email: Joi.string().email().required(),
age: Joi.number().integer().min(0).max(150),
});
app.use(validate(userSchema, "body"));防止 SQL 注入
const db = require("./db");
// ❌ Dangerous: user input is spliced straight into the SQL text (SQL injection)
async function unsafeQuery(ctx, next) {
  const sql = `SELECT * FROM users WHERE id = ${ctx.params.id}`;
  ctx.body = await db.query(sql);
}
// ✅ 安全:使用参数化查询
async function safeQuery(ctx, next) {
const userId = ctx.params.id;
const user = await db.query("SELECT * FROM users WHERE id = ?", [userId]);
ctx.body = user;
}防止 XSS 攻击
const escapeHtml = require("escape-html");
/**
 * Sets browser-side hardening headers on every response.
 *
 * Escaping is deliberately NOT done here: escaping the whole HTML body would
 * turn the page's own markup into visible text. Escape untrusted values at
 * the point where they are interpolated into HTML instead, e.g.
 * `<p>${escapeHtml(comment)}</p>`.
 */
async function xssProtection(ctx, next) {
  // Set headers before next() so they are also present when a downstream
  // middleware throws and an outer error handler writes the response.
  ctx.set("X-Content-Type-Options", "nosniff");
  ctx.set("X-Frame-Options", "DENY");
  // The legacy XSS auditor has been removed from modern browsers and its
  // "block" mode could itself be abused; disable it and rely on escaping/CSP.
  ctx.set("X-XSS-Protection", "0");
  await next();
}
速率限制
const rateLimit = require("koa-ratelimit");
// IP-based rate limiting with koa-ratelimit's in-memory driver
// NOTE(review): behind a reverse proxy, ctx.ip is the proxy's address unless
// app.proxy is enabled — confirm for the target deployment.
const limiterHeaders = {
  remaining: "Rate-Limit-Remaining",
  reset: "Rate-Limit-Reset",
  total: "Rate-Limit-Total",
};
const limiter = rateLimit({
  driver: "memory",
  db: new Map(),
  duration: 60000, // window length: 1 minute
  errorMessage: "Too many requests, please try again later",
  id: (ctx) => ctx.ip,
  headers: limiterHeaders,
  max: 100, // at most 100 requests per window per client
  disableHeader: false,
});
app.use(limiter);
// Finer-grained, self-contained rate limiter (fixed window, in memory)
/**
 * @param {object} [options]
 * @param {number} [options.max=10] - requests allowed per window and key
 * @param {number} [options.window=60000] - window length in milliseconds
 * @param {(ctx) => string} [options.key] - client identifier (default: ctx.ip)
 * @returns {Function} Koa middleware that answers 429 once the limit is reached
 */
function createRateLimit(options = {}) {
  const { max = 10, window = 60000, key = (ctx) => ctx.ip } = options;
  const store = new Map();
  // Expired records are purged at most once per window so the store does not
  // keep one entry per client ever seen.
  let nextSweep = Date.now() + window;

  function sweepExpired(now) {
    for (const [id, rec] of store) {
      if (now > rec.reset) {
        store.delete(id);
      }
    }
    nextSweep = now + window;
  }

  return async function rateLimiter(ctx, next) {
    const identifier = key(ctx);
    const now = Date.now();
    if (now >= nextSweep) {
      sweepExpired(now);
    }

    let record = store.get(identifier);
    if (!record || now > record.reset) {
      record = { count: 0, reset: now + window };
      store.set(identifier, record);
    }

    if (record.count >= max) {
      const retryAfter = Math.ceil((record.reset - now) / 1000);
      // Standard header telling clients how many seconds to wait
      ctx.set("Retry-After", String(retryAfter));
      ctx.status = 429;
      ctx.body = {
        error: {
          message: "Too many requests",
          retryAfter,
        },
      };
      return;
    }

    record.count++;
    ctx.set("X-RateLimit-Limit", max);
    ctx.set("X-RateLimit-Remaining", max - record.count);
    ctx.set("X-RateLimit-Reset", Math.ceil(record.reset / 1000));
    await next();
  };
}
可测试性设计
依赖注入
// ❌ 紧耦合,难以测试
async function hardToTestMiddleware(ctx, next) {
const db = require("./db"); // 硬编码依赖
const user = await db.query("SELECT * FROM users WHERE id = ?", [
ctx.params.id,
]);
ctx.body = user;
await next();
}
// ✅ 可测试:依赖注入
function createUserMiddleware(deps = {}) {
const { db = require("./db") } = deps;
return async function userMiddleware(ctx, next) {
const user = await db.query("SELECT * FROM users WHERE id = ?", [
ctx.params.id,
]);
ctx.body = user;
await next();
};
}
// Production wiring: falls back to the real db module
const userMiddleware = createUserMiddleware();
app.use(userMiddleware);

// In a test, inject a mock db so no real database is touched
const mockUser = { id: 1, name: "Test" };
const mockDb = { query: jest.fn().mockResolvedValue(mockUser) };
const testMiddleware = createUserMiddleware({ db: mockDb });
纯函数设计
// ✅ Pure function: the result depends only on the arguments, so it is easy to test
/**
 * @param {number} price - original price
 * @param {number} discountPercent - percentage off, e.g. 20 for 20%
 * @returns {number} the discounted price
 */
function calculateDiscount(price, discountPercent) {
  const remainingFraction = 1 - discountPercent / 100;
  return price * remainingFraction;
}

// Applies a fixed 20% discount for VIP users; other requests are untouched.
async function discountMiddleware(ctx, next) {
  const isVip = Boolean(ctx.state.user?.isVIP);
  if (isVip) {
    ctx.state.discount = calculateDiscount(ctx.state.price, 20);
  }
  await next();
}
// Test
test("calculateDiscount should calculate correctly", () => {
  const cases = [
    [100, 20, 80],
    [100, 10, 90],
  ];
  for (const [price, percent, expected] of cases) {
    expect(calculateDiscount(price, percent)).toBe(expected);
  }
});
中间件隔离测试
// middleware/auth.test.js
const auth = require("./auth");

describe("Auth Middleware", () => {
  let ctx;
  let next;

  beforeEach(() => {
    ctx = {
      headers: {},
      state: {},
      // Mirror Koa: ctx.throw() raises, so the middleware stops at the throw.
      // A plain jest.fn() would return and let the middleware reach next().
      throw: jest.fn((status, message) => {
        const err = new Error(message);
        err.status = status;
        throw err;
      }),
    };
    next = jest.fn().mockResolvedValue();
  });

  test("should throw 401 if no token", async () => {
    await expect(auth(ctx, next)).rejects.toThrow();
    // Assert only the status code so the test does not depend on the
    // exact error message chosen by the middleware.
    expect(ctx.throw.mock.calls[0][0]).toBe(401);
    expect(next).not.toHaveBeenCalled();
  });

  test("should call next if token is valid", async () => {
    ctx.headers.authorization = "Bearer valid-token";
    ctx.state.user = { id: 1, name: "Test" };
    await auth(ctx, next);
    expect(next).toHaveBeenCalled();
    expect(ctx.throw).not.toHaveBeenCalled();
  });
});
日志记录最佳实践
结构化日志
const winston = require("winston");
// Winston logger that writes JSON-formatted entries to the console
const jsonFormat = winston.format.json();
const consoleTransport = new winston.transports.Console();
const logger = winston.createLogger({
  format: jsonFormat,
  transports: [consoleTransport],
});
/**
 * Emits structured "started", "completed" and "failed" log entries for each
 * request. Errors from downstream are logged and then re-thrown unchanged.
 */
async function structuredLogger(ctx, next) {
  const startedAt = Date.now();
  const baseFields = {
    requestId: ctx.state.requestId || uuid(),
    method: ctx.method,
    url: ctx.url,
    ip: ctx.ip,
    userAgent: ctx.headers["user-agent"],
  };
  logger.info("Request started", baseFields);

  try {
    await next();
  } catch (err) {
    const { message, stack, status } = err;
    logger.error("Request failed", {
      ...baseFields,
      error: { message, stack, status },
    });
    throw err;
  }

  logger.info("Request completed", {
    ...baseFields,
    status: ctx.status,
    duration: Date.now() - startedAt,
  });
}
敏感信息过滤
/**
 * Returns a copy of `obj` that is safe to log: values under sensitive keys are
 * replaced with "***REDACTED***".
 *
 * Keys are matched case-insensitively and by substring, so "Password",
 * "accessToken" and "client_secret" are all caught. Nested objects and arrays
 * are sanitized recursively. Primitive input (a raw string body, undefined)
 * is returned unchanged instead of being spread into an object of characters.
 * The input is never mutated.
 *
 * @param {*} obj - value to sanitize (typically a parsed request body)
 * @param {string[]} [sensitiveKeys] - key fragments whose values are redacted
 * @returns {*} sanitized copy
 */
function sanitizeForLogging(
  obj,
  sensitiveKeys = ["password", "token", "secret"]
) {
  const REDACTED = "***REDACTED***";
  const needles = sensitiveKeys.map((k) => k.toLowerCase());
  const isSensitiveKey = (key) => {
    const lower = key.toLowerCase();
    return needles.some((needle) => lower.includes(needle));
  };
  // Objects on the current recursion path, used to detect reference cycles.
  const ancestors = new Set();

  const walk = (value) => {
    if (value === null || typeof value !== "object") {
      return value;
    }
    // Binary data and dates are not key/value records; leave them as-is.
    if (ArrayBuffer.isView(value) || value instanceof Date) {
      return value;
    }
    if (ancestors.has(value)) {
      return "[Circular]";
    }
    ancestors.add(value);
    const result = Array.isArray(value)
      ? value.map(walk)
      : Object.fromEntries(
          Object.entries(value).map(([key, val]) => [
            key,
            // As before, only non-empty (truthy) values are redacted.
            isSensitiveKey(key) && val ? REDACTED : walk(val),
          ])
        );
    ancestors.delete(value);
    return result;
  };

  return walk(obj);
}
/**
 * Logs each request (with sensitive body fields redacted) and its outcome.
 * The response entry is written even when a downstream middleware throws;
 * the error itself is re-thrown unchanged.
 */
async function safeLogger(ctx, next) {
  const start = Date.now();
  // Log the request with sensitive fields filtered out
  logger.info("Request", {
    method: ctx.method,
    url: ctx.url,
    body: sanitizeForLogging(ctx.request.body),
  });

  let status;
  try {
    await next();
    status = ctx.status;
  } catch (err) {
    // ctx.status is not final yet at this point; use the error's own status.
    status = err.status || 500;
    throw err;
  } finally {
    logger.info("Response", {
      status,
      duration: Date.now() - start,
    });
  }
}
实际案例:完整的中间件栈
const Koa = require("koa");
const app = new Koa();
// 1. 错误处理
app.use(async (ctx, next) => {
try {
await next();
} catch (err) {
ctx.status = err.status || 500;
ctx.body = {
error: {
code: err.code || "INTERNAL_ERROR",
message: err.message,
},
};
ctx.app.emit("error", err, ctx);
}
});
// 2. 安全头
app.use(async (ctx, next) => {
ctx.set("X-Content-Type-Options", "nosniff");
ctx.set("X-Frame-Options", "DENY");
ctx.set("X-XSS-Protection", "1; mode=block");
ctx.set("Strict-Transport-Security", "max-age=31536000");
await next();
});
// 3. 请求 ID
app.use(async (ctx, next) => {
ctx.state.requestId = uuid();
ctx.set("X-Request-ID", ctx.state.requestId);
await next();
});
// 4. 日志
app.use(structuredLogger);
// 5. 速率限制
app.use(createRateLimit({ max: 100, window: 60000 }));
// 6. 体解析(仅对特定路由)
app.use(bodyParser());
// 7. 认证
app.use(createAuthMiddleware({ required: false }));
// 8. 路由
app.use(router.routes());
app.use(router.allowedMethods());
app.listen(3000);编写高质量中间件的关键在于遵循设计原则、注重安全性和性能,并保持代码的可测试性。通过本文的实践,你可以构建出健壮、高效的 Koa 应用。