Skip to content

Commit 2411b0a

Browse files
authored
Merge pull request #463 from AikidoSec/mark-unsafe
Mark JS values as unsafe
2 parents 10ed01d + 3cc0dd0 commit 2411b0a

File tree

6 files changed

+318
-0
lines changed

6 files changed

+318
-0
lines changed

docs/markUnsafe.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Marking Unsafe Input
2+
3+
To flag input as unsafe, you can use the `markUnsafe` function. This is useful when you want to explicitly label data as potentially dangerous, such as output from an LLM being used to generate a file name. Here's an example using OpenAI's function calling feature:
4+
5+
```js
6+
import Zen from "@aikidosec/firewall";
7+
import OpenAI from "openai";
8+
import { readFile } from "fs/promises";
9+
10+
const openai = new OpenAI();
11+
12+
const completion = await openai.chat.completions.create({
13+
model: "gpt-4",
14+
messages: [
15+
{
16+
role: "user",
17+
content: "Read the contents of the config file"
18+
}
19+
],
20+
tools: [
21+
{
22+
type: "function",
23+
function: {
24+
name: "read_file",
25+
description: "Read the contents of a file",
26+
parameters: {
27+
type: "object",
28+
properties: {
29+
filepath: {
30+
type: "string",
31+
description: "The path to the file to read"
32+
}
33+
},
34+
required: ["filepath"]
35+
}
36+
}
37+
}
38+
]
39+
});
40+
41+
const toolCall = completion.choices[0].message.tool_calls[0];
42+
const filepath = JSON.parse(toolCall.function.arguments).filepath;
43+
44+
// Mark the filepath as unsafe since it came from the LLM
45+
Zen.markUnsafe(filepath);
46+
47+
// This will be blocked if the LLM tries to perform path traversal
48+
// e.g. if filepath is "../../../etc/passwd"
49+
await readFile(filepath);
50+
```
51+
52+
This example shows how to protect against path traversal attacks when using OpenAI's function calling feature. The LLM might try to access sensitive files using path traversal (e.g., `../../../etc/passwd`), but Zen will detect and block these attempts.
53+
54+
You can also pass multiple arguments to `markUnsafe`:
55+
56+
```js
57+
Zen.markUnsafe(a, b, c);
58+
```
59+
60+
You can pass strings, objects, and arrays to `markUnsafe`. Zen will track the marked data across your application and will be able to detect any attacks that may be attempted using the marked data.
61+
62+
## Caveats when marking data as unsafe
63+
64+
⚠️ Be careful when marking data as unsafe, as it may lead to false positives. If you generate a full SQL query using an LLM and mark it as unsafe, Zen will flag all queries using that SQL as an attack.
65+
66+
BAD:
67+
68+
```js
69+
Zen.markUnsafe("SELECT * FROM users WHERE id = '' OR 1=1 -- '");
70+
71+
await db.query("SELECT * FROM users WHERE id = '' OR 1=1 -- '");
72+
```
73+
74+
GOOD:
75+
76+
```js
77+
Zen.markUnsafe("' OR 1=1 -- ");
78+
79+
await db.query("SELECT * FROM users WHERE id = '' OR 1=1 -- '");
80+
```

library/agent/Context.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export type Context = {
2323
graphql?: string[];
2424
xml?: unknown[];
2525
subdomains?: string[]; // https://expressjs.com/en/5x/api.html#req.subdomains
26+
markUnsafe?: unknown[];
2627
cache?: Map<Source, ReturnType<typeof extractStringsFromUserInput>>;
2728
/**
2829
* Used to store redirects in outgoing http(s) requests that are started by a user-supplied input (hostname and port / url) to prevent SSRF redirect attacks.
@@ -84,6 +85,7 @@ export function runWithContext<T>(context: Context, fn: () => T) {
8485
current.xml = context.xml;
8586
current.subdomains = context.subdomains;
8687
current.outgoingRequestRedirects = context.outgoingRequestRedirects;
88+
current.markUnsafe = context.markUnsafe;
8789

8890
// Clear all the cached user input strings
8991
delete current.cache;

library/agent/Source.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export const SOURCES = [
77
"graphql",
88
"xml",
99
"subdomains",
10+
"markUnsafe",
1011
] as const;
1112

1213
export type Source = (typeof SOURCES)[number];
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import * as t from "tap";
2+
import { createTestAgent } from "../../helpers/createTestAgent";
3+
import { wrap } from "../../helpers/wrap";
4+
import { checkContextForSqlInjection } from "../../vulnerabilities/sql-injection/checkContextForSqlInjection";
5+
import { SQLDialectPostgres } from "../../vulnerabilities/sql-injection/dialects/SQLDialectPostgres";
6+
import { Context, getContext, runWithContext } from "../Context";
7+
import { markUnsafe } from "./markUnsafe";
8+
9+
function createContext(): Context {
10+
return {
11+
remoteAddress: "::1",
12+
method: "POST",
13+
url: "http://localhost:4000",
14+
query: {},
15+
headers: {},
16+
body: {
17+
image: "http://localhost:4000/api/internal",
18+
},
19+
cookies: {},
20+
routeParams: {},
21+
source: "express",
22+
route: "/posts/:id",
23+
};
24+
}
25+
26+
t.test("it works", async () => {
27+
const agent = createTestAgent({});
28+
agent.start([]);
29+
30+
// No unsafe input
31+
runWithContext(createContext(), () => {
32+
const context = getContext();
33+
34+
if (!context) {
35+
throw new Error("Context is not defined");
36+
}
37+
38+
const result = checkContextForSqlInjection({
39+
sql: 'SELECT * FROM "users" WHERE id = 1',
40+
operation: "pg.query",
41+
dialect: new SQLDialectPostgres(),
42+
context: context,
43+
});
44+
t.same(result, undefined);
45+
});
46+
47+
// Unsafe string
48+
runWithContext(createContext(), () => {
49+
markUnsafe("id = 1");
50+
51+
const context = getContext();
52+
53+
if (!context) {
54+
throw new Error("Context is not defined");
55+
}
56+
57+
const result = checkContextForSqlInjection({
58+
sql: 'SELECT * FROM "users" WHERE id = 1',
59+
operation: "pg.query",
60+
dialect: new SQLDialectPostgres(),
61+
context: context,
62+
});
63+
t.same(result, {
64+
kind: "sql_injection",
65+
operation: "pg.query",
66+
source: "markUnsafe",
67+
metadata: {
68+
sql: 'SELECT * FROM "users" WHERE id = 1',
69+
},
70+
payload: "id = 1",
71+
pathsToPayload: [".[0]"],
72+
});
73+
});
74+
75+
// Unsafe object
76+
runWithContext(createContext(), () => {
77+
markUnsafe({ somePropertyThatContainsSQL: "id = 1" });
78+
79+
const context = getContext();
80+
81+
if (!context) {
82+
throw new Error("Context is not defined");
83+
}
84+
85+
const result = checkContextForSqlInjection({
86+
sql: 'SELECT * FROM "users" WHERE id = 1',
87+
operation: "pg.query",
88+
dialect: new SQLDialectPostgres(),
89+
context: context,
90+
});
91+
t.same(result, {
92+
kind: "sql_injection",
93+
operation: "pg.query",
94+
source: "markUnsafe",
95+
metadata: {
96+
sql: 'SELECT * FROM "users" WHERE id = 1',
97+
},
98+
payload: "id = 1",
99+
pathsToPayload: [".[0].somePropertyThatContainsSQL"],
100+
});
101+
});
102+
103+
// Test markUnsafe called without context
104+
const logs: string[] = [];
105+
wrap(console, "warn", function warn() {
106+
return function warn(message: string) {
107+
logs.push(message);
108+
};
109+
});
110+
markUnsafe("id = 1");
111+
t.same(logs, [
112+
"markUnsafe(...) was called without a context. The data will not be tracked. Make sure to call markUnsafe(...) within an HTTP request. If you're using serverless functions, make sure to use the handler wrapper provided by Zen.",
113+
]);
114+
115+
// Warning logged only once
116+
markUnsafe("id = 1");
117+
t.same(logs.length, 1);
118+
119+
// Test if serialize fails
120+
runWithContext(createContext(), () => {
121+
// Define an object with a circular reference
122+
const obj: Record<string, any> = {};
123+
obj.self = obj;
124+
markUnsafe(obj);
125+
});
126+
t.same(logs, [
127+
"markUnsafe(...) was called without a context. The data will not be tracked. Make sure to call markUnsafe(...) within an HTTP request. If you're using serverless functions, make sure to use the handler wrapper provided by Zen.",
128+
"markUnsafe(...) failed to serialize the data",
129+
]);
130+
131+
runWithContext(createContext(), () => {
132+
markUnsafe();
133+
});
134+
t.same(logs, [
135+
"markUnsafe(...) was called without a context. The data will not be tracked. Make sure to call markUnsafe(...) within an HTTP request. If you're using serverless functions, make sure to use the handler wrapper provided by Zen.",
136+
"markUnsafe(...) failed to serialize the data",
137+
"markUnsafe(...) was called without any data.",
138+
]);
139+
140+
runWithContext(createContext(), () => {
141+
markUnsafe(1, true, null, undefined, () => {}, Symbol("test"));
142+
});
143+
t.same(logs, [
144+
"markUnsafe(...) was called without a context. The data will not be tracked. Make sure to call markUnsafe(...) within an HTTP request. If you're using serverless functions, make sure to use the handler wrapper provided by Zen.",
145+
"markUnsafe(...) failed to serialize the data",
146+
"markUnsafe(...) was called without any data.",
147+
"markUnsafe(...) expects an object, array, or string. Received: number",
148+
"markUnsafe(...) expects an object, array, or string. Received: boolean",
149+
"markUnsafe(...) expects an object, array, or string. Received: null",
150+
"markUnsafe(...) expects an object, array, or string. Received: undefined",
151+
"markUnsafe(...) expects an object, array, or string. Received: function",
152+
"markUnsafe(...) expects an object, array, or string. Received: symbol",
153+
]);
154+
});

library/agent/context/markUnsafe.ts

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import { isPlainObject } from "../../helpers/isPlainObject";
2+
import { getInstance } from "../AgentSingleton";
3+
import { Context, updateContext } from "../Context";
4+
import { ContextStorage } from "./ContextStorage";
5+
6+
export function markUnsafe(...data: unknown[]) {
7+
const agent = getInstance();
8+
9+
if (!agent) {
10+
return;
11+
}
12+
13+
const context = ContextStorage.getStore();
14+
15+
if (!context) {
16+
logWarningMarkUnsafeWithoutContext();
17+
return;
18+
}
19+
20+
if (data.length === 0) {
21+
// eslint-disable-next-line no-console
22+
console.warn("markUnsafe(...) was called without any data.");
23+
}
24+
25+
for (const item of data) {
26+
if (
27+
!isPlainObject(item) &&
28+
!Array.isArray(item) &&
29+
typeof item !== "string"
30+
) {
31+
const type = item === null ? "null" : typeof item;
32+
// eslint-disable-next-line no-console
33+
console.warn(
34+
`markUnsafe(...) expects an object, array, or string. Received: ${type}`
35+
);
36+
continue;
37+
}
38+
39+
addPayloadToContext(context, item);
40+
}
41+
}
42+
43+
function addPayloadToContext(context: Context, payload: unknown) {
44+
try {
45+
const current = context.markUnsafe || [];
46+
const a = JSON.stringify(payload);
47+
48+
if (
49+
!current.some((item) => {
50+
// JSON.stringify is used to compare objects
51+
// without having to copy a deep equality function
52+
return JSON.stringify(item) === a;
53+
})
54+
) {
55+
current.push(payload);
56+
updateContext(context, "markUnsafe", current);
57+
}
58+
} catch (e: unknown) {
59+
if (e instanceof Error) {
60+
// eslint-disable-next-line no-console
61+
console.warn("markUnsafe(...) failed to serialize the data");
62+
}
63+
}
64+
}
65+
66+
let loggedWarningMarkUnsafeWithoutContext = false;
67+
68+
function logWarningMarkUnsafeWithoutContext() {
69+
if (loggedWarningMarkUnsafeWithoutContext) {
70+
return;
71+
}
72+
73+
// eslint-disable-next-line no-console
74+
console.warn(
75+
"markUnsafe(...) was called without a context. The data will not be tracked. Make sure to call markUnsafe(...) within an HTTP request. If you're using serverless functions, make sure to use the handler wrapper provided by Zen."
76+
);
77+
78+
loggedWarningMarkUnsafeWithoutContext = true;
79+
}

library/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import isFirewallSupported from "./helpers/isFirewallSupported";
33
import shouldEnableFirewall from "./helpers/shouldEnableFirewall";
44
import { setUser } from "./agent/context/user";
5+
import { markUnsafe } from "./agent/context/markUnsafe";
56
import { shouldBlockRequest } from "./middleware/shouldBlockRequest";
67
import { addExpressMiddleware } from "./middleware/express";
78
import { addHonoMiddleware } from "./middleware/hono";
@@ -18,6 +19,7 @@ if (supported && shouldEnable) {
1819

1920
export {
2021
setUser,
22+
markUnsafe,
2123
shouldBlockRequest,
2224
addExpressMiddleware,
2325
addHonoMiddleware,

0 commit comments

Comments
 (0)