Skip to content

Commit

Permalink
fix(vectorizer): use format_type to get primary key data types
Browse files Browse the repository at this point in the history
  • Loading branch information
jgpruitt committed Feb 20, 2025
1 parent e81da7d commit 848081d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
2 changes: 1 addition & 1 deletion projects/extension/sql/idempotent/012-vectorizer-int.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ $func$
select pg_catalog.jsonb_agg(x)
from
(
select e.attnum, e.pknum, a.attname, y.typname
select e.attnum, e.pknum, a.attname, pg_catalog.format_type(y.oid, a.atttypmod) as typname
from pg_catalog.pg_constraint k
cross join lateral pg_catalog.unnest(k.conkey) with ordinality e(attnum, pknum)
inner join pg_catalog.pg_attribute a
Expand Down
38 changes: 38 additions & 0 deletions projects/extension/sql/incremental/017-upgrade-source-pk.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

do language plpgsql $block$
declare
_vec ai.vectorizer;
_source pg_catalog.oid;
_source_pk pg_catalog.jsonb;
begin
for _vec in (select * from ai.vectorizer)
loop
_source = pg_catalog.to_regclass(pg_catalog.format('%I.%I', _vec.source_schema, _vec.source_table));
if _source is null then
continue;
end if;

select pg_catalog.jsonb_agg(x) into _source_pk
from
(
select e.attnum, e.pknum, a.attname, pg_catalog.format_type(y.oid, a.atttypmod) as typname
from pg_catalog.pg_constraint k
cross join lateral pg_catalog.unnest(k.conkey) with ordinality e(attnum, pknum)
inner join pg_catalog.pg_attribute a
on (k.conrelid operator(pg_catalog.=) a.attrelid
and e.attnum operator(pg_catalog.=) a.attnum)
inner join pg_catalog.pg_type y on (a.atttypid operator(pg_catalog.=) y.oid)
where k.conrelid operator(pg_catalog.=) _source
and k.contype operator(pg_catalog.=) 'p'
) x;

if _source_pk is null then
continue;
end if;

update ai.vectorizer u set source_pk = _source_pk
where u.id = _vec.id
;
end loop;
end;
$block$;
66 changes: 65 additions & 1 deletion projects/extension/tests/vectorizer/test_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"pknum": 2,
"attnum": 3,
"attname": "published",
"typname": "timestamptz"
"typname": "timestamp with time zone"
}
],
"view_name": "blog_embedding",
Expand Down Expand Up @@ -1815,3 +1815,67 @@ def test_create_vectorizer_privs():
, grant_to=>null
);
""")


def test_weird_primary_key():
# Test multi-column primary keys with "interesting" data types.
# Using format_type() instead of pg_type.typname is important
# because format_type() supports these "interesting" types.
# "Interesting" data types include arrays, ones with multi-word
# names, domains, ones defined in "non-standard" schemas, etc.
# This test also ensures that multi-column primary keys are
# handled correctly in the creation of the queue, trigger,
# target, and view. We also test the usage of the trigger and
# queue in the context of this "weird" primary key
with psycopg.connect(
db_url("test"), autocommit=True, row_factory=namedtuple_row
) as con:
with con.cursor() as cur:
cur.execute("create extension if not exists ai cascade")
cur.execute("create extension if not exists timescaledb")
cur.execute("create schema if not exists vec")
cur.execute("drop domain if exists vec.code cascade")
cur.execute("create domain vec.code as varchar(3)")
cur.execute("drop table if exists vec.weird")
cur.execute("""
create table vec.weird
( a text[] not null
, b vec.code not null
, c timestamp with time zone not null
, d tstzrange not null
, note text not null
, primary key (a, b, c, d)
)
""")

# create a vectorizer for the table
# language=PostgreSQL
cur.execute("""
select ai.create_vectorizer
( 'vec.weird'::regclass
, embedding=>ai.embedding_openai('text-embedding-3-small', 3)
, chunking=>ai.chunking_character_text_splitter('note')
, scheduling=> ai.scheduling_none()
, indexing=>ai.indexing_none()
, grant_to=>ai.grant_to('public')
, enqueue_existing=>false
);
""")
vectorizer_id = cur.fetchone()[0]

# insert 7 rows into the source and see if the trigger works
cur.execute("""
insert into vec.weird(a, b, c, d, note)
select
array['larry', 'moe', 'curly']
, 'xyz'
, t
, tstzrange(t, t + interval '1d', '[)')
, 'if two witches watch two watches, which witch watches which watch'
from generate_series('2025-01-06'::timestamptz, '2025-01-12'::timestamptz, interval '1d') t
""")

# check that the queue has 7 rows
cur.execute("select ai.vectorizer_queue_pending(%s)", (vectorizer_id,))
actual = cur.fetchone()[0]
assert actual == 7

0 comments on commit 848081d

Please sign in to comment.