返回博客列表
后端
2022年8月15日
16 分钟阅读

Koa.js 中间件开发最佳实践

Koa.js 中间件开发最佳实践

中间件是 Koa.js 应用的核心,良好的中间件设计可以提升应用的性能、可维护性和安全性。本文将从多个维度探讨 Koa 中间件的开发最佳实践,帮助你编写高质量的中间件代码。

中间件设计原则

单一职责原则

每个中间件应该只做一件事,并且做好:

// ❌ 错误:一个中间件做太多事情
/**
 * Anti-pattern example: a single middleware that tags the request with an id,
 * logs the request line, rejects requests without an Authorization header,
 * fills the response body, and only then hands control downstream.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function doEverything(ctx, next) {
  // Request id
  const id = uuid();
  ctx.state.requestId = id;

  // Logging
  const requestLine = `${ctx.method} ${ctx.url}`;
  console.log(requestLine);

  // Authentication check
  const hasCredentials = Boolean(ctx.headers.authorization);
  if (!hasCredentials) {
    ctx.throw(401);
  }

  // Business logic
  const payload = await getData();
  ctx.body = payload;

  await next();
}
 
// ✅ 正确:拆分职责
/**
 * Stores a freshly generated id on `ctx.state.requestId`, then continues the chain.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function requestId(ctx, next) {
  const id = uuid();
  Object.assign(ctx.state, { requestId: id });
  await next();
}
 
/**
 * Prints "<method> <url>" for the incoming request, then continues the chain.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function logger(ctx, next) {
  const requestLine = [ctx.method, ctx.url].join(" ");
  console.log(requestLine);
  await next();
}
 
/**
 * Rejects the request with 401 when no Authorization header is present;
 * otherwise continues the chain.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function auth(ctx, next) {
  const { authorization } = ctx.headers;
  if (!authorization) {
    ctx.throw(401);
  }
  await next();
}

可复用性

设计通用、可配置的中间件:

// ❌ 硬编码的中间件
/**
 * Anti-pattern example: the protected path is hard-coded inside the middleware.
 * Every request is passed downstream exactly once.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function apiAuth(ctx, next) {
  const isUsersEndpoint = ctx.path === "/api/users";
  if (isUsersEndpoint) {
    // Hard-coded authentication logic would live here.
  }
  await next();
}
 
// ✅ 可复用的中间件
/**
 * Builds a configurable authentication middleware.
 *
 * @param {object} [options]
 * @param {boolean} [options.required=true] - When truthy, a missing token or a
 *   failed verification ends the request with 401.
 * @param {Function} [options.getToken] - `(ctx) => token`; by default reads the
 *   Authorization header and removes the first "Bearer " occurrence.
 * @param {Function} [options.verifyToken] - `async (token) => user`; its result
 *   is stored on `ctx.state.user`.
 * @returns {Function} Koa middleware `(ctx, next)`.
 */
function createAuthMiddleware(options = {}) {
  function defaultGetToken(ctx) {
    return ctx.headers.authorization?.replace("Bearer ", "");
  }

  async function defaultVerifyToken(token) {
    // Placeholder for the default verification logic.
  }

  const {
    required = true,
    getToken = defaultGetToken,
    verifyToken = defaultVerifyToken,
  } = options;

  return async function auth(ctx, next) {
    const token = getToken(ctx);
    const hasToken = Boolean(token);

    if (!hasToken && required) {
      ctx.throw(401, "Authentication required");
    }

    if (hasToken) {
      try {
        ctx.state.user = await verifyToken(token);
      } catch (err) {
        // With optional auth a bad token is ignored and the request goes on
        // without a user.
        if (required) {
          ctx.throw(401, "Invalid token");
        }
      }
    }

    await next();
  };
}
 
// Usage: build an auth middleware that requires a token and delegates the
// check to customVerifyToken (defined elsewhere).
const auth = createAuthMiddleware({
  required: true,
  verifyToken: customVerifyToken,
});
 
// Register it on the Koa application.
app.use(auth);

幂等性

中间件应该是幂等的:对同一个请求多次执行,结果应该相同,并且不应依赖或修改跨请求共享的全局状态:

// ❌ 非幂等:每次执行都会修改全局状态
// Module-level tally shared by every request that passes through `counter`.
let requestCount = 0;

/**
 * Anti-pattern example: bumps the shared module-level tally and copies the new
 * value to `ctx.state.count` before continuing the chain.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function counter(ctx, next) {
  requestCount += 1; // shared state changes on every call
  Object.assign(ctx.state, { count: requestCount });
  await next();
}
 
// ✅ 幂等:不依赖全局状态
/**
 * Sets `ctx.state.localCount` to 1 for the current request (no shared state
 * involved), then continues the chain.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function requestCounter(ctx, next) {
  // The value lives on this request's own state object.
  Object.assign(ctx.state, { localCount: 1 });
  await next();
}

错误处理最佳实践

统一错误格式

// 定义自定义错误类
/**
 * Base class for application errors. Carries an HTTP status code and an
 * optional machine-readable code; `name` is taken from the concrete class.
 */
class AppError extends Error {
  /**
   * @param {string} message - Error message.
   * @param {number} [statusCode=500] - HTTP status to respond with.
   * @param {?string} [code=null] - Machine-readable error code.
   */
  constructor(message, statusCode = 500, code = null) {
    super(message);
    const { name } = this.constructor;
    Object.assign(this, { statusCode, code, name });
    Error.captureStackTrace(this, this.constructor);
  }
}

/** 400 error with code "VALIDATION_ERROR" and a list of validation problems. */
class ValidationError extends AppError {
  /**
   * @param {string} message - Error message.
   * @param {Array} [errors=[]] - Validation problems, stored on `errors`.
   */
  constructor(message, errors = []) {
    const status = 400;
    super(message, status, "VALIDATION_ERROR");
    this.errors = errors;
  }
}

/** 404 error with code "NOT_FOUND" and the message "<resource> not found". */
class NotFoundError extends AppError {
  /**
   * @param {string} [resource="Resource"] - Name used in the message.
   */
  constructor(resource = "Resource") {
    const message = `${resource} not found`;
    super(message, 404, "NOT_FOUND");
  }
}
 
// 错误处理中间件
/**
 * Error-handling middleware that turns anything thrown downstream into a JSON
 * body of the form `{ error: { code, message, ... } }`.
 *
 * - `AppError` instances: their own status, code, message and (if set) `errors`.
 * - Errors with a truthy `status` and `expose`: code "HTTP_ERROR".
 * - Everything else: 500 / "INTERNAL_ERROR"; the real message and stack are
 *   included only when NODE_ENV is "development", and the error is emitted on
 *   `ctx.app`.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function errorHandler(ctx, next) {
  const respond = (status, error) => {
    ctx.status = status;
    ctx.body = { error };
  };

  try {
    await next();
  } catch (err) {
    if (err instanceof AppError) {
      // Known application error
      const error = { code: err.code, message: err.message };
      if (err.errors) {
        error.errors = err.errors;
      }
      respond(err.statusCode, error);
    } else if (err.status && err.expose) {
      // Error raised through ctx.throw
      respond(err.status, { code: "HTTP_ERROR", message: err.message });
    } else {
      // Unknown error
      const error = {
        code: "INTERNAL_ERROR",
        message: "An unexpected error occurred",
      };
      if (process.env.NODE_ENV === "development") {
        error.message = err.message;
        error.stack = err.stack;
      }
      respond(500, error);

      // Report the error to the application's "error" listeners.
      ctx.app.emit("error", err, ctx);
    }
  }
}

异步错误处理

// ❌ 错误:无法捕获异步错误
/**
 * Anti-pattern example: after the downstream chain has finished, schedules a
 * timer callback that throws.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function unsafeMiddleware(ctx, next) {
  await next();
 
  // The throw below runs inside the timer callback, not inside this async
  // function, so an outer try-catch around the middleware never sees it.
  setTimeout(() => {
    throw new Error("Async error"); // not caught
  }, 100);
}
 
// ✅ 正确:使用 Promise 处理异步错误
/**
 * After the downstream chain finishes, runs a timer-based operation wrapped in
 * a promise so that a failure inside the callback becomes a rejection, which
 * is then reported through `ctx.throw(500, ...)`.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function safeMiddleware(ctx, next) {
  await next();

  // Adapt the timer callback to a promise: anything thrown inside the
  // callback is routed to reject() instead of escaping.
  const runDelayed = () =>
    new Promise((resolve, reject) => {
      const onTimeout = () => {
        try {
          // Asynchronous work goes here
          resolve();
        } catch (err) {
          reject(err);
        }
      };
      setTimeout(onTimeout, 100);
    });

  try {
    await runDelayed();
  } catch (err) {
    ctx.throw(500, err.message);
  }
}

错误分类处理

/**
 * Error-handling middleware that picks the response by error type
 * (matched on `err.name`):
 *
 * - "ValidationError"   -> 400 with the error's message and `details`
 * - "UnauthorizedError" -> 401 with a fixed message
 * - "ForbiddenError"    -> 403 with a fixed message
 * - anything else       -> `err.status` (or 500)
 *
 * In the fallback branch the original message is returned to the client only
 * when the error is marked as exposable (`err.expose`, or a status below 500
 * when `expose` is not set). Other errors get a generic message so internal
 * details are not leaked, and are emitted on `ctx.app` as an "error" event so
 * they can still be logged.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function errorHandler(ctx, next) {
  try {
    await next();
  } catch (err) {
    if (err.name === "ValidationError") {
      ctx.status = 400;
      ctx.body = {
        error: {
          type: "validation",
          message: err.message,
          details: err.details,
        },
      };
    } else if (err.name === "UnauthorizedError") {
      ctx.status = 401;
      ctx.body = {
        error: {
          type: "authentication",
          message: "Authentication required",
        },
      };
    } else if (err.name === "ForbiddenError") {
      ctx.status = 403;
      ctx.body = {
        error: {
          type: "authorization",
          message: "Insufficient permissions",
        },
      };
    } else {
      // Fallback: never echo the message of an unexpected server-side error.
      const status = err.status || 500;
      const exposeMessage = err.expose ?? status < 500;

      ctx.status = status;
      ctx.body = {
        error: {
          type: "server",
          message: exposeMessage ? err.message : "Internal Server Error",
        },
      };

      if (!exposeMessage) {
        // Keep the full error available to the application's error listeners.
        ctx.app.emit("error", err, ctx);
      }
    }
  }
}

性能优化策略

减少不必要的计算

// ❌ 每次都执行重计算
/**
 * Anti-pattern example: runs the expensive query on every single request and
 * stores the result on `ctx.state.data`.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function expensiveMiddleware(ctx, next) {
  // Recomputed for each request
  const result = await expensiveDatabaseQuery();
  Object.assign(ctx.state, { data: result });
  await next();
}
 
// ✅ 使用缓存
// Query results keyed by request URL; each entry is { data, timestamp }.
const cache = new Map();
const CACHE_TTL = 60000; // 1 minute, in milliseconds

/**
 * Serves `ctx.state.data` from the in-memory cache while the entry for
 * `ctx.url` is younger than CACHE_TTL; otherwise runs the query again and
 * stores the fresh result. The chain is continued in both cases.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function cachedMiddleware(ctx, next) {
  const cacheKey = ctx.url;
  const entry = cache.get(cacheKey);
  const isFresh = Boolean(entry) && Date.now() - entry.timestamp < CACHE_TTL;

  if (isFresh) {
    ctx.state.data = entry.data;
  } else {
    const data = await expensiveDatabaseQuery();
    cache.set(cacheKey, { data, timestamp: Date.now() });
    ctx.state.data = data;
  }

  await next();
}

条件执行优化

// ❌ 无条件执行所有逻辑
/**
 * Anti-pattern example: loads user, settings and permissions for every
 * request, one after another, whether or not the route needs them.
 *
 * @param {object} ctx - Koa context; reads `ctx.state.userId`.
 * @param {Function} next - Continues the middleware chain.
 */
async function inefficientMiddleware(ctx, next) {
  // All three lookups always run, in this order.
  const loaded = {};
  loaded.user = await getUser(ctx.state.userId);
  loaded.settings = await getSettings(ctx.state.userId);
  loaded.permissions = await getPermissions(ctx.state.userId);

  Object.assign(ctx.state, loaded);

  await next();
}
 
// ✅ 按需加载
/**
 * Loads data only for the routes that need it: the user for paths under
 * "/api/protected", the permissions for paths under "/api/admin".
 *
 * @param {object} ctx - Koa context; reads `ctx.path` and `ctx.state.userId`.
 * @param {Function} next - Continues the middleware chain.
 */
async function efficientMiddleware(ctx, next) {
  // [path prefix, state key, loader] — evaluated in order, only on a match.
  const loaders = [
    ["/api/protected", "user", (id) => getUser(id)],
    ["/api/admin", "permissions", (id) => getPermissions(id)],
  ];

  for (const [prefix, stateKey, load] of loaders) {
    if (ctx.path.startsWith(prefix)) {
      ctx.state[stateKey] = await load(ctx.state.userId);
    }
  }

  await next();
}

批量处理

// ❌ 串行处理
/**
 * Anti-pattern example: three independent lookups awaited one after another;
 * each one starts only when the previous one has finished.
 *
 * @param {object} ctx - Koa context; reads `ctx.state.userId`.
 * @param {Function} next - Continues the middleware chain.
 */
async function slowMiddleware(ctx, next) {
  const steps = [
    ["user", (id) => getUser(id)],
    ["settings", (id) => getSettings(id)],
    ["permissions", (id) => getPermissions(id)],
  ];

  // Strictly sequential on purpose.
  for (const [stateKey, load] of steps) {
    ctx.state[stateKey] = await load(ctx.state.userId);
  }

  await next();
}
 
// ✅ 并行处理
/**
 * Starts the user, settings and permissions lookups together and waits for
 * all of them with `Promise.all` before storing the results on `ctx.state`.
 *
 * @param {object} ctx - Koa context; reads `ctx.state.userId`.
 * @param {Function} next - Continues the middleware chain.
 */
async function fastMiddleware(ctx, next) {
  // All three requests are in flight before the first await.
  const pending = [
    getUser(ctx.state.userId),
    getSettings(ctx.state.userId),
    getPermissions(ctx.state.userId),
  ];
  const [user, settings, permissions] = await Promise.all(pending);

  Object.assign(ctx.state, { user, settings, permissions });

  await next();
}

使用流处理大文件

const fs = require("fs");
const path = require("path");
 
/**
 * Streams a file from the local "uploads" directory for `GET ...?file=<name>`.
 * Requests that are not GET, or that carry no `file` query parameter, are
 * passed downstream untouched.
 *
 * The requested name is untrusted input: it must be a single string and must
 * resolve to a path inside the uploads directory, otherwise the request is
 * rejected (400 / 403) instead of being handed to the file system.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function fileStreamMiddleware(ctx, next) {
  const requested = ctx.query.file;

  if (ctx.method !== "GET" || !requested) {
    await next();
    return;
  }

  // A repeated parameter (?file=a&file=b) is parsed into an array, which
  // path.join would reject with a TypeError.
  if (typeof requested !== "string") {
    ctx.throw(400, "Invalid file parameter");
  }

  const uploadsDir = path.join(__dirname, "uploads");
  const filePath = path.join(uploadsDir, requested);

  // path.join collapses ".." segments, so compare the final location with the
  // uploads directory to block traversal such as ?file=../../etc/passwd.
  const relative = path.relative(uploadsDir, filePath);
  const escapesUploads =
    relative === "" ||
    relative === ".." ||
    relative.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relative);
  if (escapesUploads) {
    ctx.throw(403, "Forbidden");
  }

  // Stream the file instead of loading it into memory in one piece.
  ctx.type = path.extname(filePath);
  ctx.body = fs.createReadStream(filePath);

  // next() is intentionally not called: the response is already set.
}

安全性最佳实践

输入验证

const Joi = require("joi");
 
/**
 * Builds a middleware that validates one part of the request against a Joi
 * schema and replaces that part with the validated value.
 *
 * Only the validation call itself is guarded: a Joi error raised by the schema
 * becomes a 400 response listing the failing fields, while errors thrown later
 * by downstream middleware (including Joi errors from other validations)
 * propagate unchanged.
 *
 * @param {object} schema - Object exposing `validateAsync(data, options)`.
 * @param {string} [source="body"] - "body", "query", or anything else for
 *   `ctx.params`.
 * @returns {Function} Koa middleware `(ctx, next)`.
 */
function validate(schema, source = "body") {
  return async function validator(ctx, next) {
    let data;
    if (source === "body") {
      data = ctx.request.body;
    } else if (source === "query") {
      data = ctx.query;
    } else {
      data = ctx.params;
    }

    let value;
    try {
      value = await schema.validateAsync(data, {
        abortEarly: false,
        stripUnknown: true, // drop fields the schema does not know
      });
    } catch (err) {
      if (err.isJoi) {
        ctx.throw(400, {
          message: "Validation failed",
          errors: err.details.map((detail) => ({
            field: detail.path.join("."),
            message: detail.message,
          })),
        });
      }
      throw err;
    }

    // Downstream middleware works with the validated data.
    if (source === "body") {
      ctx.request.body = value;
    } else if (source === "query") {
      ctx.query = value;
    } else {
      ctx.params = value;
    }

    await next();
  };
}
 
// Usage: describe the expected request body...
const userSchema = Joi.object({
  name: Joi.string().required().max(100),
  email: Joi.string().email().required(),
  age: Joi.number().integer().min(0).max(150),
});
 
// ...and validate the body of incoming requests against it.
app.use(validate(userSchema, "body"));

防止 SQL 注入

const db = require("./db");
 
// ❌ 危险:SQL 注入
/**
 * Anti-pattern example (SQL injection): the id from the route parameters is
 * interpolated straight into the SQL text.
 *
 * @param {object} ctx - Koa context; reads `ctx.params.id`.
 * @param {Function} next - Unused.
 */
async function unsafeQuery(ctx, next) {
  // Dangerous: request data becomes part of the statement itself.
  const sql = `SELECT * FROM users WHERE id = ${ctx.params.id}`;
  ctx.body = await db.query(sql);
}
 
// ✅ 安全:使用参数化查询
/**
 * Looks up a user with a parameterized query: the id travels as a bound value
 * next to the SQL text instead of being spliced into it.
 *
 * @param {object} ctx - Koa context; reads `ctx.params.id`.
 * @param {Function} next - Unused.
 */
async function safeQuery(ctx, next) {
  const sql = "SELECT * FROM users WHERE id = ?";
  const values = [ctx.params.id];
  ctx.body = await db.query(sql, values);
}

防止 XSS 攻击

const escapeHtml = require("escape-html");
 
/**
 * Runs the downstream chain, HTML-escapes string bodies of "text/html"
 * responses, and sets three security response headers.
 *
 * The headers are applied in a `finally` block so that they are also present
 * when downstream middleware throws and an outer error handler produces the
 * response.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function xssProtection(ctx, next) {
  try {
    await next();

    // NOTE(review): escapeHtml is applied to the complete body string, i.e. to
    // the page's own markup as well as to any user-supplied content in it —
    // confirm this is intended rather than escaping values where they are
    // inserted into the page.
    if (ctx.type === "text/html" && typeof ctx.body === "string") {
      ctx.body = escapeHtml(ctx.body);
    }
  } finally {
    // Security headers, also for responses produced after an error.
    ctx.set("X-Content-Type-Options", "nosniff");
    ctx.set("X-Frame-Options", "DENY");
    ctx.set("X-XSS-Protection", "1; mode=block");
  }
}

速率限制

const rateLimit = require("koa-ratelimit");
 
// IP-based rate limiting, configured through koa-ratelimit.
const limiter = rateLimit({
  driver: "memory",
  db: new Map(),
  duration: 60000, // 1 minute
  errorMessage: "Too many requests, please try again later",
  // Requests are grouped by client IP.
  id: (ctx) => ctx.ip,
  headers: {
    remaining: "Rate-Limit-Remaining",
    reset: "Rate-Limit-Reset",
    total: "Rate-Limit-Total",
  },
  max: 100, // at most 100 requests per minute
  disableHeader: false,
});
 
// Register the limiter on the Koa application.
app.use(limiter);
 
// 更细粒度的速率限制
/**
 * Builds a fixed-window, in-memory rate-limiting middleware.
 *
 * Every response — including the 429 rejection — carries the
 * X-RateLimit-Limit / -Remaining / -Reset headers; rejections additionally
 * carry a Retry-After header (seconds). Records whose window has expired are
 * removed from the store at most once per window, so identifiers that never
 * come back do not accumulate forever.
 *
 * @param {object} [options]
 * @param {number} [options.max=10] - Requests allowed per window.
 * @param {number} [options.window=60000] - Window length in milliseconds.
 * @param {Function} [options.key] - `(ctx) => identifier`; defaults to `ctx.ip`.
 * @returns {Function} Koa middleware `(ctx, next)`.
 */
function createRateLimit(options = {}) {
  const { max = 10, window = 60000, key = (ctx) => ctx.ip } = options;
  const store = new Map();
  let nextSweepAt = 0;

  // Drops every record whose window is over; runs at most once per window.
  function sweepExpired(now) {
    if (now < nextSweepAt) {
      return;
    }
    for (const [identifier, record] of store) {
      if (now > record.reset) {
        store.delete(identifier);
      }
    }
    nextSweepAt = now + window;
  }

  return async function rateLimiter(ctx, next) {
    const identifier = key(ctx);
    const now = Date.now();

    sweepExpired(now);

    let record = store.get(identifier);
    if (!record || now > record.reset) {
      // First request of a new window for this identifier.
      record = { count: 0, reset: now + window };
      store.set(identifier, record);
    }

    const limited = record.count >= max;
    if (!limited) {
      record.count++;
    }

    ctx.set("X-RateLimit-Limit", max);
    ctx.set("X-RateLimit-Remaining", max - record.count);
    ctx.set("X-RateLimit-Reset", Math.ceil(record.reset / 1000));

    if (limited) {
      const retryAfter = Math.ceil((record.reset - now) / 1000);
      ctx.set("Retry-After", retryAfter);
      ctx.status = 429;
      ctx.body = {
        error: {
          message: "Too many requests",
          retryAfter,
        },
      };
      return;
    }

    await next();
  };
}

可测试性设计

依赖注入

// ❌ 紧耦合,难以测试
async function hardToTestMiddleware(ctx, next) {
  const db = require("./db"); // 硬编码依赖
  const user = await db.query("SELECT * FROM users WHERE id = ?", [
    ctx.params.id,
  ]);
  ctx.body = user;
  await next();
}
 
// ✅ 可测试:依赖注入
/**
 * Builds the user-lookup middleware around an injectable database client.
 *
 * @param {object} [deps]
 * @param {object} [deps.db] - Object exposing `query(sql, values)`; the
 *   "./db" module is required only when none is supplied.
 * @returns {Function} Koa middleware `(ctx, next)` that stores the query
 *   result in `ctx.body` and continues the chain.
 */
function createUserMiddleware(deps = {}) {
  const { db = require("./db") } = deps;
  const findUserSql = "SELECT * FROM users WHERE id = ?";

  return async function userMiddleware(ctx, next) {
    const values = [ctx.params.id];
    ctx.body = await db.query(findUserSql, values);
    await next();
  };
}
 
// Usage: no dependencies passed, so the factory falls back to require("./db").
const userMiddleware = createUserMiddleware();
app.use(userMiddleware);
 
// In tests, inject a mock whose query() resolves to a fixed row.
const mockDb = {
  query: jest.fn().mockResolvedValue({ id: 1, name: "Test" }),
};
const testMiddleware = createUserMiddleware({ db: mockDb });

纯函数设计

// ✅ 纯函数,易于测试
/**
 * Pure helper: applies a percentage discount to a price.
 *
 * @param {number} price - Original price.
 * @param {number} discountPercent - Discount in percent (20 means 20%).
 * @returns {number} `price * (1 - discountPercent / 100)`.
 */
function calculateDiscount(price, discountPercent) {
  const remainingFraction = 1 - discountPercent / 100;
  return price * remainingFraction;
}
 
/**
 * Stores a 20% discounted price on `ctx.state.discount` when the current user
 * is flagged `isVIP`, then continues the chain.
 *
 * @param {object} ctx - Koa context; reads `ctx.state.user` and `ctx.state.price`.
 * @param {Function} next - Continues the middleware chain.
 */
async function discountMiddleware(ctx, next) {
  const isVip = Boolean(ctx.state.user?.isVIP);
  if (isVip) {
    const discounted = calculateDiscount(ctx.state.price, 20);
    ctx.state.discount = discounted;
  }
  await next();
}
 
// Test: a pure function can be checked with plain input/output pairs,
// no context or mocks required.
test("calculateDiscount should calculate correctly", () => {
  expect(calculateDiscount(100, 20)).toBe(80);
  expect(calculateDiscount(100, 10)).toBe(90);
});

中间件隔离测试

// middleware/auth.test.js
const auth = require("./auth");
 
// Isolated tests for the auth middleware: it is called directly with a
// hand-built context and a mocked next().
describe("Auth Middleware", () => {
  let ctx;
  let next;

  beforeEach(() => {
    ctx = {
      headers: {},
      state: {},
      // Koa's ctx.throw always throws. The mock has to do the same, otherwise
      // the middleware under test would simply keep running after a "failed"
      // check and the assertions below could not tell the difference.
      throw: jest.fn((status, message) => {
        const err = new Error(message);
        err.status = status;
        throw err;
      }),
    };
    next = jest.fn().mockResolvedValue();
  });

  test("should throw 401 if no token", async () => {
    await expect(auth(ctx, next)).rejects.toMatchObject({ status: 401 });

    expect(ctx.throw).toHaveBeenCalledWith(401, "Unauthorized");
    expect(next).not.toHaveBeenCalled();
  });

  test("should call next if token is valid", async () => {
    ctx.headers.authorization = "Bearer valid-token";
    ctx.state.user = { id: 1, name: "Test" };

    await auth(ctx, next);

    expect(next).toHaveBeenCalled();
    expect(ctx.throw).not.toHaveBeenCalled();
  });
});

日志记录最佳实践

结构化日志

const winston = require("winston");
 
// Shared winston logger: entries are formatted as JSON and written to the console.
const logger = winston.createLogger({
  format: winston.format.json(),
  transports: [new winston.transports.Console()],
});
 
/**
 * Logs one structured entry when a request starts and one when it completes
 * (with status and duration in milliseconds). If downstream middleware throws,
 * a "Request failed" entry with the error's message, stack and status is
 * logged and the error is re-thrown.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function structuredLogger(ctx, next) {
  const startedAt = Date.now();

  // Fields shared by every entry written for this request. An id set by an
  // earlier middleware is reused; otherwise a new one is generated.
  const logData = {
    requestId: ctx.state.requestId || uuid(),
    method: ctx.method,
    url: ctx.url,
    ip: ctx.ip,
    userAgent: ctx.headers["user-agent"],
  };

  logger.info("Request started", logData);

  try {
    await next();
  } catch (err) {
    const { message, stack, status } = err;
    logger.error("Request failed", {
      ...logData,
      error: { message, stack, status },
    });
    throw err;
  }

  const duration = Date.now() - startedAt;
  const completed = { ...logData, status: ctx.status, duration };
  logger.info("Request completed", completed);
}

敏感信息过滤

/**
 * Returns a copy of `obj` in which the values of sensitive keys are replaced
 * by "***REDACTED***", so the copy can be written to logs.
 *
 * Redaction is applied at every depth: nested plain objects and arrays are
 * copied and scrubbed as well, so `{ user: { password } }` or a list of
 * credential objects does not leak. The input is never mutated. Other values
 * (primitives, class instances such as Date) are passed through as they are,
 * and a cyclic reference is replaced by "[Circular]".
 *
 * As before, the top level is shallow-copied with object spread, so `null` /
 * `undefined` produce `{}`, and a sensitive key is redacted only when its
 * value is truthy.
 *
 * @param {object} obj - Data about to be logged.
 * @param {string[]} [sensitiveKeys] - Exact key names to redact.
 * @returns {object} Scrubbed copy.
 */
function sanitizeForLogging(
  obj,
  sensitiveKeys = ["password", "token", "secret"]
) {
  const REDACTED = "***REDACTED***";
  const sensitive = new Set(sensitiveKeys);
  // Containers on the current recursion path, used to stop on cycles.
  const ancestors = new WeakSet();

  const isPlainObject = (value) => {
    if (value === null || typeof value !== "object") {
      return false;
    }
    const proto = Object.getPrototypeOf(value);
    return proto === Object.prototype || proto === null;
  };

  // Shallow-copies `source`, then redacts or recurses per own key.
  function scrubEntries(source) {
    const copy = { ...source };
    for (const key of Object.keys(copy)) {
      const value = copy[key];
      if (sensitive.has(key)) {
        if (value) {
          copy[key] = REDACTED;
        }
      } else {
        copy[key] = scrubValue(value);
      }
    }
    return copy;
  }

  // Scrubs arrays and plain objects; returns everything else unchanged.
  function scrubValue(value) {
    const isList = Array.isArray(value);
    if (!isList && !isPlainObject(value)) {
      return value;
    }
    if (ancestors.has(value)) {
      return "[Circular]";
    }
    ancestors.add(value);
    const scrubbed = isList ? value.map(scrubValue) : scrubEntries(value);
    // Removed again so a value referenced twice (but not cyclically) is kept.
    ancestors.delete(value);
    return scrubbed;
  }

  if (obj !== null && typeof obj === "object") {
    ancestors.add(obj);
  }
  return scrubEntries(obj);
}
 
/**
 * Logs the request (method, url and the body after it has been passed through
 * `sanitizeForLogging`) and, once the chain has finished, the response status
 * and the elapsed time in milliseconds.
 *
 * @param {object} ctx - Koa context.
 * @param {Function} next - Continues the middleware chain.
 */
async function safeLogger(ctx, next) {
  const startedAt = Date.now();

  // Request entry — the body is scrubbed before it reaches the log.
  const body = sanitizeForLogging(ctx.request.body);
  const requestEntry = { method: ctx.method, url: ctx.url, body };
  logger.info("Request", requestEntry);

  await next();

  // Response entry
  const responseEntry = {
    status: ctx.status,
    duration: Date.now() - startedAt,
  };
  logger.info("Response", responseEntry);
}

实际案例:完整的中间件栈

const Koa = require("koa");
const app = new Koa();

// 1. Error handling — registered first so it wraps every middleware below.
//    The original message is sent to the client only for exposable errors
//    (err.expose, or a status below 500 when expose is not set); anything
//    else gets a generic message so internal details are not leaked. The
//    full error is always emitted on the app for logging.
app.use(async (ctx, next) => {
  try {
    await next();
  } catch (err) {
    const status = err.status || 500;
    const exposeMessage = err.expose ?? status < 500;

    ctx.status = status;
    ctx.body = {
      error: {
        code: err.code || "INTERNAL_ERROR",
        message: exposeMessage ? err.message : "Internal Server Error",
      },
    };
    ctx.app.emit("error", err, ctx);
  }
});

// 2. Security headers
app.use(async (ctx, next) => {
  ctx.set("X-Content-Type-Options", "nosniff");
  ctx.set("X-Frame-Options", "DENY");
  ctx.set("X-XSS-Protection", "1; mode=block");
  ctx.set("Strict-Transport-Security", "max-age=31536000");
  await next();
});

// 3. Request id — stored on ctx.state and echoed in the X-Request-ID header
app.use(async (ctx, next) => {
  ctx.state.requestId = uuid();
  ctx.set("X-Request-ID", ctx.state.requestId);
  await next();
});

// 4. Logging
app.use(structuredLogger);

// 5. Rate limiting
app.use(createRateLimit({ max: 100, window: 60000 }));

// 6. Body parsing — registered with app.use, so it runs for every request
//    (it is not limited to specific routes here)
app.use(bodyParser());

// 7. Authentication (optional: requests without a token are let through)
app.use(createAuthMiddleware({ required: false }));

// 8. Routes
app.use(router.routes());
app.use(router.allowedMethods());

app.listen(3000);

编写高质量中间件的关键在于遵循设计原则、注重安全性和性能,并保持代码的可测试性。通过本文的实践,你可以构建出健壮、高效的 Koa 应用。