-
-
Save ikupenov/10bc89d92d92eaba8cc5569013e04069 to your computer and use it in GitHub Desktop.
import { and, type DBQueryConfig, eq, type SQLWrapper } from "drizzle-orm"; | |
import { drizzle } from "drizzle-orm/postgres-js"; | |
import postgres, { type Sql } from "postgres"; | |
import { type AnyArgs } from "@/common"; | |
import { | |
type DbClient, | |
type DbTable, | |
type DeleteArgs, | |
type DeleteFn, | |
type FindArgs, | |
type FindFn, | |
type FromArgs, | |
type FromFn, | |
type InsertArgs, | |
type JoinArgs, | |
type JoinFn, | |
type Owner, | |
type RlsDbClient, | |
type SetArgs, | |
type SetFn, | |
type UpdateArgs, | |
type ValuesArgs, | |
type ValuesFn, | |
type WhereArgs, | |
type WhereFn, | |
} from "./db-client.types"; | |
import * as schema from "./schema"; | |
/**
 * Opens a postgres-js connection.
 *
 * @param connectionString - Postgres connection URL, handed unchanged to `postgres`.
 * @returns the postgres-js `Sql` client for that connection.
 */
export const connectDb = (connectionString: string) =>
  postgres(connectionString);
/**
 * Builds an unscoped Drizzle client over an existing postgres-js connection,
 * registering the project schema so the relational `db.query.*` API is available.
 *
 * @param client - postgres-js connection (see `connectDb`).
 * @returns the Drizzle client.
 */
export const createDbClient = (client: Sql): DbClient =>
  drizzle(client, { schema });
/**
 * Wraps a Drizzle client in a Proxy that scopes every query to `owner`.
 *
 * Method calls are intercepted by their property path (e.g.
 * `db.select.from.where`, `db.query.users.findMany`). For tables that have an
 * `ownerId` column, the interceptor ANDs `ownerId = owner.id` into reads,
 * updates and deletes, and stamps `ownerId` onto inserted rows. `execute` is
 * rejected because raw SQL cannot be scoped. Transactions (including nested
 * ones started from a scoped `tx`) hand their callback a scoped `tx`, so
 * queries inside a transaction cannot bypass the policy.
 *
 * @param client - postgres-js connection used to build the underlying Drizzle client.
 * @param owner - principal whose `id` every scoped query is filtered by.
 * @returns the Drizzle client (typed without `execute`) with owner scoping applied.
 */
export const createRlsDbClient = (client: Sql, owner: Owner): RlsDbClient => {
  const db = createDbClient(client);
  const ownerIdColumn = "ownerId" as const;

  // eslint-disable-next-line import/namespace
  const getTable = (table: DbTable) => schema[table];

  /** Builds the `ownerId = owner.id` filter for a table (or relational field map). */
  const getAccessPolicy = (
    table: {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      [ownerIdColumn]: any;
    },
    owner: Owner,
  ) => eq(table[ownerIdColumn], owner.id);

  /** Position of an intercepted call: its property path and the calls leading to it. */
  interface InvokeContext {
    path?: string[];
    fnPath?: { name: string; args: unknown[] }[];
  }

  /** The original (bound) method being intercepted plus the arguments it was called with. */
  interface InterceptFn {
    invoke: (...args: unknown[]) => unknown;
    name: string;
    args: unknown[];
  }

  /** A path pattern (`*` = wildcard) and the replacement behavior for matching calls. */
  interface OverrideFn {
    pattern: string | string[];
    action: () => unknown;
  }

  /**
   * Expands `db.`-rooted patterns so they also match the same calls made on a
   * proxied transaction, whose path root is `tx`.
   */
  const forDbAndTx = (...patterns: string[]): string[] =>
    patterns.flatMap((pattern) => [pattern, pattern.replace(/^db\./, "tx.")]);

  /** Arguments of the most recent call named `name` in a builder chain, if any. */
  const findCallArgs = (
    fnPath: NonNullable<InvokeContext["fnPath"]>,
    name: string,
  ) => [...fnPath].reverse().find((call) => call.name === name)?.args;

  /**
   * Resolves a relational-query `where` (SQL, callback, or undefined) to SQL.
   * Drizzle accepts both forms, so both must be combinable with the policy.
   */
  const resolveWhere = (
    where: unknown,
    fields: unknown,
    operators: unknown,
  ): SQLWrapper | undefined =>
    typeof where === "function"
      ? (where as (f: unknown, o: unknown) => SQLWrapper | undefined)(
          fields,
          operators,
        )
      : (where as SQLWrapper | undefined);

  /**
   * Adds the owner policy to every relation in a relational `with` config and
   * recurses into nested `with`s so deeply loaded relations are scoped too.
   * Relations whose fields have no `ownerId` keep only the caller's filter.
   */
  const scopeWith = (
    withConfig: Record<string, unknown>,
  ): Record<string, unknown> =>
    Object.fromEntries(
      Object.entries(withConfig).map(([relation, value]) => {
        if (value === true) {
          return [
            relation,
            {
              where: (fields: object) =>
                ownerIdColumn in fields
                  ? getAccessPolicy(fields, owner)
                  : undefined,
            },
          ];
        }
        if (typeof value === "object" && value !== null) {
          const relationConfig = value as DBQueryConfig<"many">;
          return [
            relation,
            {
              ...relationConfig,
              where: (fields: object, operators: unknown) => {
                const userWhere = resolveWhere(
                  relationConfig.where,
                  fields,
                  operators,
                );
                return ownerIdColumn in fields
                  ? and(getAccessPolicy(fields, owner), userWhere)
                  : userWhere;
              },
              ...(relationConfig.with
                ? {
                    with: scopeWith(
                      relationConfig.with as Record<string, unknown>,
                    ),
                  }
                : {}),
            },
          ];
        }
        // `false` / `null` / anything else is passed through untouched.
        return [relation, value];
      }),
    );

  /**
   * Runs an intercepted call, replacing it with the first override whose
   * pattern matches the call's path, or invoking it unchanged otherwise.
   */
  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    // `*` matches any run of characters (dots included); everything else is literal.
    const matchPath = (pattern: string) =>
      new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      ).test(pathAsString);

    const overrides: OverrideFn[] = [
      {
        // Without this, the callback would receive the raw `tx` and every
        // query inside the transaction would skip the policy. `tx.transaction`
        // (nested transaction / savepoint) is wrapped the same way.
        pattern: forDbAndTx("db.transaction"),
        action: () => {
          const [callback, ...rest] = fn.args as [
            (tx: object) => unknown,
            ...unknown[],
          ];
          // `rest` forwards the optional transaction config (isolation level etc.).
          return fn.invoke(
            async (tx: object) => callback(createProxy(tx, { path: ["tx"] })),
            ...rest,
          );
        },
      },
      {
        pattern: forDbAndTx("db.execute", "db.*.execute"),
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        pattern: forDbAndTx(
          "db.query.findMany",
          "db.query.*.findMany",
          "db.query.findFirst",
          "db.query.*.findFirst",
        ),
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          // Path shape is `<root>.query.<table>.<method>`.
          const tableName = path[path.indexOf("query") + 1] as DbTable;
          const table = getTable(tableName);
          if (!(ownerIdColumn in table)) {
            return findFn(...findArgs);
          }

          const [config] = findArgs;
          const policy = getAccessPolicy(table, owner);
          const userWhere = config?.where;
          const scopedConfig = {
            ...config,
            where:
              typeof userWhere === "function"
                ? // Callback-form `where`: evaluate it lazily and AND the policy in.
                  (fields: unknown, operators: unknown) =>
                    and(policy, resolveWhere(userWhere, fields, operators))
                : userWhere
                  ? and(policy, userWhere as SQLWrapper)
                  : policy,
            ...(config?.with
              ? { with: scopeWith(config.with as Record<string, unknown>) }
              : {}),
          };
          // The rebuilt config keeps drizzle's runtime shape; its static type
          // is looser than FindArgs because `where`/`with` were rebuilt above.
          return findFn(...([scopedConfig] as unknown as FindArgs));
        },
      },
      {
        pattern: forDbAndTx("db.*.from"),
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs;
          return ownerIdColumn in table
            ? fromFn(...fromArgs).where(getAccessPolicy(table, owner))
            : fromFn(...fromArgs);
        },
      },
      {
        // A later `.where()` replaces the policy applied in `from`, so re-AND it.
        pattern: forDbAndTx("db.*.from.where", "db.*.from.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "from") as FromArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: forDbAndTx(
          "db.*.leftJoin",
          "db.*.rightJoin",
          "db.*.innerJoin",
          "db.*.fullJoin",
        ),
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOptions] = joinArgs;
          if (ownerIdColumn in table) {
            return joinFn(
              table,
              and(getAccessPolicy(table, owner), joinOptions as SQLWrapper),
            );
          }
          return joinFn(...joinArgs);
        },
      },
      {
        pattern: forDbAndTx("db.insert.values"),
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          const [table] = findCallArgs(fnPath, "insert") as InsertArgs;
          if (!(ownerIdColumn in table)) {
            return valuesFn(...valuesArgs);
          }
          let [valuesToInsert] = valuesArgs;
          if (!Array.isArray(valuesToInsert)) {
            valuesToInsert = [valuesToInsert];
          }
          // Force ownership on every row, overriding any caller-supplied owner.
          const valuesToInsertWithOwner = valuesToInsert.map((value) => ({
            ...value,
            [ownerIdColumn]: owner.id,
          }));
          return valuesFn(valuesToInsertWithOwner);
        },
      },
      {
        pattern: forDbAndTx("db.update.set"),
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = findCallArgs(fnPath, "update") as UpdateArgs;
          return ownerIdColumn in table
            ? setFn(...setArgs).where(getAccessPolicy(table, owner))
            : setFn(...setArgs);
        },
      },
      {
        pattern: forDbAndTx("db.update.where", "db.update.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "update") as UpdateArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: forDbAndTx("db.delete"),
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs;
          return ownerIdColumn in table
            ? deleteFn(...deleteArgs).where(getAccessPolicy(table, owner))
            : deleteFn(...deleteArgs);
        },
      },
      {
        pattern: forDbAndTx("db.delete.where", "db.delete.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "delete") as DeleteArgs;
          if (ownerIdColumn in table) {
            const [whereOptions] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, owner), whereOptions as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
    ];

    // First matching override wins, so list order matters.
    const fnOverride = overrides.find(({ pattern }) =>
      (Array.isArray(pattern) ? pattern : [pattern]).some(matchPath),
    )?.action;

    return fnOverride ? fnOverride() : fn.invoke(...fn.args);
  };

  /**
   * Recursively proxies `target`: every method call is routed through
   * `intercept`, and object results / nested objects are proxied again with
   * the extended path so later calls in the chain are intercepted too.
   */
  const createProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const currentPath = path.concat(innerTargetProp.toString());
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );
        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [
              ...fnPath,
              { name: innerTargetProp.toString(), args },
            ];
            const result = intercept(
              {
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: innerTargetProp.toString(),
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              return createProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        } else if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // wrap nested objects in a proxy as well
          return createProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }
        return innerTargetPropValue;
      },
    });
  };

  return createProxy(db, { path: ["db"] });
};
import { type drizzle } from "drizzle-orm/postgres-js"; | |
import type * as schema from "./schema"; | |
/** Type-only handle used to derive the client and builder types below; no runtime value. */
declare const db: ReturnType<typeof drizzle<typeof schema>>;

/** Principal whose `id` is used to scope RLS queries (`null` is permitted by the type). */
export interface Owner {
  id: string | null;
}

export type DbClient = typeof db;
export type DbSchema = typeof schema;
export type DbTable = keyof DbSchema;
/** Client surface exposed to scoped callers: everything except raw `execute`. */
export type RlsDbClient = Omit<DbClient, "execute">;

/** Relational query namespace (`db.query`), keyed by table. */
type DbQuery = DbClient["query"];

/** Either `findFirst` or `findMany` of the relational query API for table `K`. */
export type FindFn<K extends keyof DbQuery = keyof DbQuery> = (
  ...args:
    | Parameters<DbQuery[K]["findFirst"]>
    | Parameters<DbQuery[K]["findMany"]>
) =>
  | ReturnType<DbQuery[K]["findFirst"]>
  | ReturnType<DbQuery[K]["findMany"]>;
export type FindArgs<K extends keyof DbQuery = keyof DbQuery> = Parameters<
  FindFn<K>
>;

// select(...).from(...).where(...) / .leftJoin(...)
export type SelectFn = DbClient["select"];
export type SelectArgs = Parameters<SelectFn>;
export type FromFn = ReturnType<SelectFn>["from"];
export type FromArgs = Parameters<FromFn>;
export type WhereFn = ReturnType<FromFn>["where"];
export type WhereArgs = Parameters<WhereFn>;
export type JoinFn = ReturnType<FromFn>["leftJoin"];
export type JoinArgs = Parameters<JoinFn>;

// insert(...).values(...)
export type InsertFn = DbClient["insert"];
export type InsertArgs = Parameters<InsertFn>;
export type ValuesFn = ReturnType<InsertFn>["values"];
export type ValuesArgs = Parameters<ValuesFn>;

// update(...).set(...)
export type UpdateFn = DbClient["update"];
export type UpdateArgs = Parameters<UpdateFn>;
export type SetFn = ReturnType<UpdateFn>["set"];
export type SetArgs = Parameters<SetFn>;

// delete(...)
export type DeleteFn = DbClient["delete"];
export type DeleteArgs = Parameters<DeleteFn>;
This is amazing, I got it working too. Did you come up with this yourself? I hope the Drizzle team integrates this solution into their promised RLS support.
This is amazing, I got it working too. Did you come up with this yourself? I hope the Drizzle team integrates this solution into their promised RLS support.
I think RLS is a fundamentally different concept? This automatically adds some `where` clauses, right?
This is amazing, I got it working too. Did you come up with this yourself? I hope the Drizzle team integrates this solution into their promised RLS support.
I think RLS is a fundamentally different concept? This automatically adds some `where` clauses, right?
Yes, you are right, I misunderstood it. It's more like "automatic scoping" or "implicit filtering". What would you call it?
I copied the code, and after some modifications it is miraculously working with my MySQL setup.
However, it does not work with transactions:
// Illustrates the limitation: the transaction callback receives a raw `tx`, not a proxied client.
db.transaction(async (tx) => {
// ...operations on `tx` (e.g. tx.select) do not go through the intercept function, so no access policy is applied
});
If someone finds a solution, I would appreciate it.
I was able to make transactions work with a promising solution!
I updated all of the patterns to also match calls on `tx.`, and then added one more pattern that matches `db.transaction`
and overrides it, reusing the proxy function to wrap the `tx` object passed to the callback so that each of its
properties is intercepted as well. I can't believe this is actually working 🤯
/* eslint-disable @typescript-eslint/no-unsafe-call */
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import type { DBQueryConfig, SQLWrapper } from "drizzle-orm";
import { and, eq } from "drizzle-orm";
import type { db as _db } from "../client";
import type {
DbTable,
DeleteArgs,
DeleteFn,
FindArgs,
FindFn,
FromArgs,
FromFn,
InsertArgs,
JoinArgs,
JoinFn,
SetArgs,
SetFn,
Team,
TeamDbClient,
UpdateArgs,
ValuesArgs,
ValuesFn,
WhereArgs,
WhereFn,
} from "./teamDb.types";
import { db } from "../client";
import * as schema from "../schema";
// Loosely typed argument list for proxied calls whose signatures are not known statically.
type AnyArgs = Array<any>;

/** Where an intercepted call sits: its property path and the call chain leading to it. */
interface InvokeContext {
  /** Property path from the root, e.g. ["db", "select", "from"]. */
  path?: string[];
  /** Each call made so far in the chain, with the arguments it received. */
  fnPath?: { name: string; args: unknown[] }[];
}

/** A method call captured by the proxy: the bound original, its name, and its arguments. */
interface InterceptFn {
  invoke: (...args: unknown[]) => unknown;
  name: string;
  args: unknown[];
}

/** Path pattern(s) (with `*` wildcards) and the behavior that replaces matching calls. */
interface OverrideFn {
  pattern: string | string[];
  action: () => unknown;
}
/**
 * Returns the shared `db` client wrapped in a Proxy that scopes queries to `team`.
 *
 * Method calls are intercepted by property path (e.g. `db.select.from.where`,
 * `tx.query.users.findMany`). For tables with a `teamId` column, reads,
 * updates and deletes get `teamId = team.id` ANDed into their filter, and
 * inserted rows get `teamId` set to `team.id`. `execute` is rejected because
 * raw SQL cannot be scoped. Transactions, including nested transactions
 * started from a scoped `tx`, receive a scoped `tx`.
 *
 * @param team - team whose `id` scopes every query.
 * @returns the team-scoped client.
 */
export const getTeamDb = (team: Team): TeamDbClient => {
  const teamIdColumn = "teamId";
  const getTable = (table: DbTable) => schema[table];

  /** Builds the `teamId = owner.id` filter for a table (or relational field map). */
  const getAccessPolicy = (
    table: {
      [teamIdColumn]: any;
    },
    owner: Team,
  ) => eq(table[teamIdColumn], owner.id);

  /**
   * Expands `db.`-rooted patterns so they also match the same calls made on a
   * proxied transaction, whose path root is `tx`.
   */
  const forDbAndTx = (...patterns: string[]): string[] =>
    patterns.flatMap((pattern) => [pattern, pattern.replace(/^db\./, "tx.")]);

  /** Arguments of the most recent call named `name` in a builder chain, if any. */
  const findCallArgs = (
    fnPath: NonNullable<InvokeContext["fnPath"]>,
    name: string,
  ) => [...fnPath].reverse().find((call) => call.name === name)?.args;

  /**
   * Resolves a relational-query `where` (SQL, callback, or undefined) to SQL.
   * Drizzle accepts both forms, so both must be combinable with the policy.
   */
  const resolveWhere = (
    where: unknown,
    fields: unknown,
    operators: unknown,
  ): SQLWrapper | undefined =>
    typeof where === "function"
      ? (where as (f: unknown, o: unknown) => SQLWrapper | undefined)(
          fields,
          operators,
        )
      : (where as SQLWrapper | undefined);

  /**
   * Adds the team policy to every relation in a relational `with` config and
   * recurses into nested `with`s so deeply loaded relations are scoped too.
   * Relations whose fields have no `teamId` keep only the caller's filter.
   */
  const scopeWith = (
    withConfig: Record<string, unknown>,
  ): Record<string, unknown> =>
    Object.fromEntries(
      Object.entries(withConfig).map(([relation, value]) => {
        if (value === true) {
          return [
            relation,
            {
              where: (fields: object) =>
                teamIdColumn in fields
                  ? getAccessPolicy(fields, team)
                  : undefined,
            },
          ];
        }
        if (typeof value === "object" && value !== null) {
          const relationConfig = value as DBQueryConfig<"many">;
          return [
            relation,
            {
              ...relationConfig,
              where: (fields: object, operators: unknown) => {
                const userWhere = resolveWhere(
                  relationConfig.where,
                  fields,
                  operators,
                );
                return teamIdColumn in fields
                  ? and(getAccessPolicy(fields, team), userWhere)
                  : userWhere;
              },
              ...(relationConfig.with
                ? {
                    with: scopeWith(
                      relationConfig.with as Record<string, unknown>,
                    ),
                  }
                : {}),
            },
          ];
        }
        // `false` / `null` / anything else is passed through untouched.
        return [relation, value];
      }),
    );

  /**
   * Runs an intercepted call, replacing it with the first override whose
   * pattern matches the call's path, or invoking it unchanged otherwise.
   */
  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    // `*` matches any run of characters (dots included); everything else is literal.
    const matchPath = (pattern: string) =>
      new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      ).test(pathAsString);

    const overrides: OverrideFn[] = [
      {
        // Hand the callback a scoped `tx`. `tx.transaction` (nested
        // transaction / savepoint) must be wrapped too, or it would escape.
        pattern: forDbAndTx("db.transaction"),
        action: () => {
          const [callback, ...rest] = fn.args as [
            (tx: object) => unknown,
            ...unknown[],
          ];
          // `rest` forwards the optional transaction config (isolation level etc.).
          return fn.invoke(
            async (tx: object) => callback(createProxy(tx, { path: ["tx"] })),
            ...rest,
          );
        },
      },
      {
        pattern: forDbAndTx("db.execute", "db.*.execute"),
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        pattern: forDbAndTx(
          "db.query.findMany",
          "db.query.*.findMany",
          "db.query.findFirst",
          "db.query.*.findFirst",
        ),
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          // Path shape is `<root>.query.<table>.<method>`.
          const tableName = path[path.indexOf("query") + 1] as DbTable;
          const table = getTable(tableName);
          if (!(teamIdColumn in table)) {
            return findFn(...findArgs);
          }

          const [config] = findArgs;
          const policy = getAccessPolicy(table, team);
          const userWhere = config?.where;
          const scopedConfig = {
            ...config,
            where:
              typeof userWhere === "function"
                ? // Callback-form `where`: evaluate it lazily and AND the policy in.
                  (fields: unknown, operators: unknown) =>
                    and(policy, resolveWhere(userWhere, fields, operators))
                : userWhere
                  ? and(policy, userWhere as SQLWrapper)
                  : policy,
            ...(config?.with
              ? { with: scopeWith(config.with as Record<string, unknown>) }
              : {}),
          };
          // The rebuilt config keeps drizzle's runtime shape; its static type
          // is looser than FindArgs because `where`/`with` were rebuilt above.
          return findFn(...([scopedConfig] as unknown as FindArgs));
        },
      },
      {
        pattern: forDbAndTx("db.*.from"),
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs;
          return teamIdColumn in table
            ? fromFn(...fromArgs).where(getAccessPolicy(table, team))
            : fromFn(...fromArgs);
        },
      },
      {
        // A later `.where()` replaces the policy applied in `from`, so re-AND it.
        pattern: forDbAndTx("db.*.from.where", "db.*.from.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "from") as FromArgs;
          if (teamIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, team), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: forDbAndTx(
          "db.*.leftJoin",
          "db.*.rightJoin",
          "db.*.innerJoin",
          "db.*.fullJoin",
        ),
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOptions] = joinArgs;
          if (teamIdColumn in table) {
            return joinFn(
              table,
              and(getAccessPolicy(table, team), joinOptions as SQLWrapper),
            );
          }
          return joinFn(...joinArgs);
        },
      },
      {
        pattern: forDbAndTx("db.insert.values"),
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          const [table] = findCallArgs(fnPath, "insert") as InsertArgs;
          if (!(teamIdColumn in table)) {
            return valuesFn(...valuesArgs);
          }
          let [valuesToInsert] = valuesArgs;
          if (!Array.isArray(valuesToInsert)) {
            valuesToInsert = [valuesToInsert];
          }
          // Stamp the scoping column itself (`teamId`) on every row,
          // overriding any caller-supplied value.
          const valuesToInsertWithTeam = valuesToInsert.map((value) => ({
            ...value,
            [teamIdColumn]: team.id,
          }));
          return valuesFn(valuesToInsertWithTeam);
        },
      },
      {
        pattern: forDbAndTx("db.update.set"),
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = findCallArgs(fnPath, "update") as UpdateArgs;
          return teamIdColumn in table
            ? setFn(...setArgs).where(getAccessPolicy(table, team))
            : setFn(...setArgs);
        },
      },
      {
        pattern: forDbAndTx("db.update.where", "db.update.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "update") as UpdateArgs;
          if (teamIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, team), whereFilter as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: forDbAndTx("db.delete"),
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs;
          return teamIdColumn in table
            ? deleteFn(...deleteArgs).where(getAccessPolicy(table, team))
            : deleteFn(...deleteArgs);
        },
      },
      {
        pattern: forDbAndTx("db.delete.where", "db.delete.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findCallArgs(fnPath, "delete") as DeleteArgs;
          if (teamIdColumn in table) {
            const [whereOptions] = whereArgs;
            return whereFn(
              and(getAccessPolicy(table, team), whereOptions as SQLWrapper),
            );
          }
          return whereFn(...whereArgs);
        },
      },
    ];

    // First matching override wins, so list order matters.
    const fnOverride = overrides.find(({ pattern }) =>
      (Array.isArray(pattern) ? pattern : [pattern]).some(matchPath),
    )?.action;

    return fnOverride ? fnOverride() : fn.invoke(...fn.args);
  };

  /**
   * Recursively proxies `target`: every method call is routed through
   * `intercept`, and object results / nested objects are proxied again with
   * the extended path so later calls in the chain are intercepted too.
   */
  const createProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const currentPath = path.concat(innerTargetProp.toString());
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );
        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [
              ...fnPath,
              { name: innerTargetProp.toString(), args },
            ];
            const result = intercept(
              {
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: innerTargetProp.toString(),
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              return createProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        } else if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // wrap nested objects in a proxy as well
          return createProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }
        return innerTargetPropValue;
      },
    });
  };

  return createProxy(db, { path: ["db"] });
};
Glad it's working for you guys! This is a separate thing from the DB-level RLS that the Drizzle team is working on. This solution is app-level RLS.
And yeah, in the initial version transactions were not supported. That's fixed now and I have made some improvements since then that allow for more flexible policies defined at a table level. I have created a new gist if you're interested - https://gist.github.com/ikupenov/26f3775821c05f17b6f8b7a037fb2c7a.
Here's an example:
// schema/entities/example-entity.ts
import { and, eq, isNotNull, or, sql } from "drizzle-orm";
import { pgTable, text, uuid } from "drizzle-orm/pg-core";
import { hasRole } from "@sheetah/common";
import { policy } from "@sheetah/db/orm";
import {
getOptionalOrgOwnedBaseEntityProps,
getOwnedBaseEntityProps,
} from "./base";
/**
 * `entity` table definition.
 * Shared columns are spread in from `getOwnedBaseEntityProps()` and
 * `getOptionalOrgOwnedBaseEntityProps()` (from `./base`). The policy below
 * reads `ownerId` and `organizationId`, which presumably come from those
 * helpers — confirm in `./base`.
 */
export const entities = pgTable("entity", {
...getOwnedBaseEntityProps(),
...getOptionalOrgOwnedBaseEntityProps(),
description: text("description"),
// Required and unique: no two entities may share a transaction id.
transactionId: uuid("transaction_id").unique().notNull(),
categoryId: uuid("category_id"),
taxRateId: uuid("tax_rate_id"),
});
// Row access: the owning user, or an org admin of the row's organization.
// Every branch falls back to `false` so a missing identity grants nothing.
policy(entities, ({ userId, orgId, role }) => {
  const ownsRow = userId ? eq(entities.ownerId, userId) : sql`false`;

  const isAdminOfRowOrg =
    orgId && hasRole(role, ["org:admin"])
      ? and(
          isNotNull(entities.organizationId),
          eq(entities.organizationId, orgId),
        )
      : sql`false`;

  return or(ownsRow, isAdminOfRowOrg) ?? sql`false`;
});
/** Row type selected from the `entities` table. */
export type Entity = typeof entities.$inferSelect;
FYI: Drizzle v1 introduces a breaking change for this pattern because the query builder API is being revamped, so this solution will need to be updated.
ERR TypeError:
db.select(...).from is not a function
db.insert(...).values is not a function