security: add helmet, rate limiting, strict CORS, input sanitization.
- Add helmet for secure HTTP response headers.
- Add express-rate-limit: 60 req/min general, 20 req/min on LLM endpoints.
- Restrict CORS to localhost origins in dev, CORS_ORIGIN env var in prod.
- Cap request body at 16kb.
- Add sanitizeText() to strip control chars on all string inputs.
- Add isValidStandardId() regex guard on :id param and standard_id fields.
- All route handlers read string inputs through sanitizeText(); numeric and boolean inputs are parsed and clamped instead of being used raw from req.body/req.query.
This commit is contained in:
+97
-36
@@ -1,9 +1,11 @@
|
||||
require("dotenv").config();
|
||||
|
||||
const express = require("express");
|
||||
const cors = require("cors");
|
||||
const path = require("path");
|
||||
const fs = require("fs");
|
||||
const express = require("express");
|
||||
const cors = require("cors");
|
||||
const helmet = require("helmet");
|
||||
const rateLimit = require("express-rate-limit");
|
||||
const path = require("path");
|
||||
const fs = require("fs");
|
||||
|
||||
const { generateExplanation, answerQuestion, rewriteQuery } = require("./services/llmService");
|
||||
const { retrieve } = require("./services/retrieverService");
|
||||
@@ -20,8 +22,50 @@ if (!process.env.GROQ_API_KEY) {
|
||||
);
|
||||
}
|
||||
|
||||
app.use(cors());
|
||||
app.use(express.json());
|
||||
// ── Security headers ─────────────────────────────────────────────────────────
|
||||
|
||||
app.use(helmet());
|
||||
|
||||
// ── CORS — restrict to configured origin or localhost dev ────────────────────
|
||||
|
||||
const ALLOWED_ORIGINS = process.env.CORS_ORIGIN
|
||||
? process.env.CORS_ORIGIN.split(",").map((o) => o.trim())
|
||||
: ["http://localhost:5173", "http://localhost:4173", `http://localhost:${PORT}`];
|
||||
|
||||
app.use(cors({
|
||||
origin: (origin, cb) => {
|
||||
// Allow non-browser requests (curl, server-to-server) and configured origins
|
||||
if (!origin || ALLOWED_ORIGINS.includes(origin)) return cb(null, true);
|
||||
cb(new Error(`CORS: origin ${origin} not allowed`));
|
||||
},
|
||||
methods: ["GET", "POST"],
|
||||
allowedHeaders: ["Content-Type"],
|
||||
}));
|
||||
|
||||
// ── Rate limiting ─────────────────────────────────────────────────────────────
|
||||
|
||||
const apiLimiter = rateLimit({
|
||||
windowMs: 60 * 1000,
|
||||
max: 60,
|
||||
standardHeaders: true,
|
||||
legacyHeaders: false,
|
||||
message: { error: "Too many requests. Please wait a moment and try again." },
|
||||
});
|
||||
|
||||
const llmLimiter = rateLimit({
|
||||
windowMs: 60 * 1000,
|
||||
max: 20,
|
||||
standardHeaders: true,
|
||||
legacyHeaders: false,
|
||||
message: { error: "AI request limit reached. Please wait before trying again." },
|
||||
});
|
||||
|
||||
app.use("/api/", apiLimiter);
|
||||
app.use("/api/recommend", llmLimiter);
|
||||
app.use("/api/ask", llmLimiter);
|
||||
app.use("/api/chat", llmLimiter);
|
||||
|
||||
app.use(express.json({ limit: "16kb" }));
|
||||
|
||||
// ── Load data ───────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -55,6 +99,21 @@ for (const c of chunks) {
|
||||
chunksByStd[c.standard_id].push(c);
|
||||
}
|
||||
|
||||
// ── Input sanitization ────────────────────────────────────────────────────────
|
||||
|
||||
// ASCII control characters except TAB (\x09), LF (\x0A) and CR (\x0D), plus DEL.
const CONTROL_CHAR_RE = /[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g;

// A UTF-16 high surrogate left without its low half at the very end of a string.
const DANGLING_HIGH_SURROGATE_RE = /[\uD800-\uDBFF]$/;

/**
 * Normalise an untrusted text value before it is used by a route handler.
 *
 * Control characters are removed, surrounding whitespace is trimmed, and the
 * result is capped at `maxLen` UTF-16 code units.
 *
 * @param {unknown} value - Raw input; anything that is not a string yields "".
 * @param {number} [maxLen=500] - Maximum length of the returned string.
 * @returns {string} The cleaned string, never longer than `maxLen`.
 */
function sanitizeText(value, maxLen = 500) {
  if (typeof value !== "string") return "";

  // Trim BEFORE truncating so leading whitespace does not consume the length
  // budget (a padded, over-length input must not collapse to an empty string).
  const cleaned = value.replace(CONTROL_CHAR_RE, "").trim();
  if (cleaned.length <= maxLen) return cleaned;

  // Truncation can cut a surrogate pair in half and can expose trailing
  // whitespace that was previously interior; tidy up both.
  return cleaned
    .slice(0, maxLen)
    .replace(DANGLING_HIGH_SURROGATE_RE, "")
    .trimEnd();
}
|
||||
|
||||
// standard_id must match IS identifier pattern: letters/digits/spaces/colons/parens/dots/hyphens/slashes
const STANDARD_ID_RE = /^[A-Za-z0-9 :()./-]{1,60}$/;

/**
 * Check whether a value is an acceptable standard identifier.
 *
 * The value is trimmed before matching, so callers should look up the
 * trimmed form.
 *
 * @param {unknown} id - Candidate identifier (route param or body field).
 * @returns {boolean} True when `id` is a string whose trimmed form matches
 *   STANDARD_ID_RE and is not the name of an Object.prototype member.
 */
function isValidStandardId(id) {
  if (typeof id !== "string") return false;

  const candidate = id.trim();
  if (!STANDARD_ID_RE.test(candidate)) return false;

  // Purely alphabetic names such as "constructor" or "toString" satisfy the
  // pattern; used as a key on a plain-object lookup table they would resolve
  // to inherited members rather than to a standard, so reject them here.
  return !(candidate in Object.prototype);
}
|
||||
|
||||
// ── Structured logger ───────────────────────────────────────────────────────
|
||||
|
||||
function log(endpoint, data) {
|
||||
@@ -115,9 +174,10 @@ function bestChunk(standardId, question) {
|
||||
|
||||
// ── GET /api/standards ──────────────────────────────────────────────────────
|
||||
app.get("/api/standards", (req, res) => {
|
||||
const { q = "", category = "", page = "1", limit = "20" } = req.query;
|
||||
const pageNum = Math.max(1, parseInt(page));
|
||||
const limitNum = Math.min(100, Math.max(1, parseInt(limit)));
|
||||
const q = sanitizeText(req.query.q || "", 200);
|
||||
const category = sanitizeText(req.query.category || "", 100);
|
||||
const pageNum = Math.max(1, parseInt(req.query.page) || 1);
|
||||
const limitNum = Math.min(100, Math.max(1, parseInt(req.query.limit) || 20));
|
||||
|
||||
let results = standards;
|
||||
if (category) results = results.filter((s) => s.category === category);
|
||||
@@ -139,9 +199,12 @@ app.get("/api/standards", (req, res) => {
|
||||
|
||||
// ── GET /api/standards/:id ──────────────────────────────────────────────────
|
||||
app.get("/api/standards/:id", (req, res) => {
|
||||
const id = decodeURIComponent(req.params.id);
|
||||
const standard = standardsById[id];
|
||||
if (!standard) return res.status(404).json({ error: "Standard not found" });
|
||||
const raw = decodeURIComponent(req.params.id);
|
||||
if (!isValidStandardId(raw)) {
|
||||
return res.status(400).json({ error: "Invalid standard ID format." });
|
||||
}
|
||||
const standard = standardsById[raw.trim()];
|
||||
if (!standard) return res.status(404).json({ error: "Standard not found." });
|
||||
res.json(standard);
|
||||
});
|
||||
|
||||
@@ -175,19 +238,19 @@ app.get("/api/stats", (req, res) => {
|
||||
* Output: { standards, latency: { retrieval_ms, llm_ms, total_ms } }
|
||||
*/
|
||||
app.post("/api/recommend", async (req, res) => {
|
||||
const { query, top_n = 5, rewrite = false } = req.body;
|
||||
const rawQuery = req.body?.query;
|
||||
const top_n = Math.min(10, Math.max(1, parseInt(req.body?.top_n) || 5));
|
||||
const rewrite = req.body?.rewrite === true;
|
||||
|
||||
if (!query || typeof query !== "string" || !query.trim()) {
|
||||
return res.status(400).json({ error: "query is required." });
|
||||
}
|
||||
if (query.length > 500) {
|
||||
return res.status(400).json({ error: "query must be 500 characters or fewer." });
|
||||
const query = sanitizeText(rawQuery, 500);
|
||||
if (!query) {
|
||||
return res.status(400).json({ error: "query is required and must be a non-empty string." });
|
||||
}
|
||||
|
||||
const t0 = Date.now();
|
||||
|
||||
// Step 1 — Optional query rewrite (fires concurrently, falls back silently)
|
||||
let effectiveQuery = query.trim();
|
||||
let effectiveQuery = query;
|
||||
if (rewrite && process.env.GROQ_API_KEY) {
|
||||
effectiveQuery = await rewriteQuery(query.trim()); // never throws
|
||||
}
|
||||
@@ -196,7 +259,7 @@ app.post("/api/recommend", async (req, res) => {
|
||||
let retrievalResult;
|
||||
const tRetStart = Date.now();
|
||||
try {
|
||||
retrievalResult = await retrieve(effectiveQuery, Math.min(top_n, 10));
|
||||
retrievalResult = await retrieve(effectiveQuery, top_n);
|
||||
} catch (err) {
|
||||
console.error("[recommend] Retrieval error:", err.message);
|
||||
return res.status(502).json({ error: "Retrieval service unavailable. Please try again." });
|
||||
@@ -264,16 +327,14 @@ app.post("/api/recommend", async (req, res) => {
|
||||
* Output: { answer, source: { standard_id, section, chunk_id } }
|
||||
*/
|
||||
app.post("/api/ask", async (req, res) => {
|
||||
const { question, standard_id } = req.body;
|
||||
const question = sanitizeText(req.body?.question, 500);
|
||||
const standard_id = sanitizeText(req.body?.standard_id, 60);
|
||||
|
||||
if (!question || typeof question !== "string" || !question.trim()) {
|
||||
return res.status(400).json({ error: "question is required." });
|
||||
if (!question) {
|
||||
return res.status(400).json({ error: "question is required and must be a non-empty string." });
|
||||
}
|
||||
if (!standard_id || typeof standard_id !== "string") {
|
||||
return res.status(400).json({ error: "standard_id is required." });
|
||||
}
|
||||
if (question.length > 500) {
|
||||
return res.status(400).json({ error: "question must be 500 characters or fewer." });
|
||||
if (!standard_id || !isValidStandardId(standard_id)) {
|
||||
return res.status(400).json({ error: "standard_id is required and must be a valid IS identifier." });
|
||||
}
|
||||
|
||||
const t0 = Date.now();
|
||||
@@ -284,7 +345,7 @@ app.post("/api/ask", async (req, res) => {
|
||||
}
|
||||
|
||||
const tLlm = Date.now();
|
||||
const answer = await answerQuestion(question.trim(), chunk.text); // never throws
|
||||
const answer = await answerQuestion(question, chunk.text); // never throws
|
||||
const llmMs = Date.now() - tLlm;
|
||||
const totalMs = Date.now() - t0;
|
||||
|
||||
@@ -316,16 +377,16 @@ app.post("/api/chat", async (req, res) => {
|
||||
return res.status(503).json({ error: "AI features are not configured on this server." });
|
||||
}
|
||||
|
||||
const { standard_id, question } = req.body;
|
||||
const question = sanitizeText(req.body?.question, 500);
|
||||
const standard_id = sanitizeText(req.body?.standard_id || "", 60);
|
||||
|
||||
if (!question || typeof question !== "string" || !question.trim()) {
|
||||
return res.status(400).json({ error: "question is required." });
|
||||
}
|
||||
if (question.length > 500) {
|
||||
return res.status(400).json({ error: "question must be 500 characters or fewer." });
|
||||
if (!question) {
|
||||
return res.status(400).json({ error: "question is required and must be a non-empty string." });
|
||||
}
|
||||
|
||||
const std = standard_id ? standardsById[standard_id] : null;
|
||||
const std = (standard_id && isValidStandardId(standard_id))
|
||||
? standardsById[standard_id] ?? null
|
||||
: null;
|
||||
let chunkText = "";
|
||||
|
||||
if (std) {
|
||||
|
||||
Reference in New Issue
Block a user