Skip to content

Commit

Permalink
feat: Implement PgVector Type
Browse files Browse the repository at this point in the history
  • Loading branch information
28Smiles committed Sep 28, 2024
1 parent ac86987 commit e375754
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 15 deletions.
17 changes: 17 additions & 0 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ pub enum Type {
/// Variable-length multidimensional array
Array(ArrayDef),

#[cfg(feature = "postgres-vector")]
/// The postgres vector type introduced by the vector extension.
Vector(VectorDef),

// TODO:
// /// The structure of a row or record; a list of field names and types
// Composite,
Expand Down Expand Up @@ -268,6 +272,14 @@ pub struct ArrayDef {
pub col_type: Option<RcOrArc<Type>>,
}

#[cfg(feature = "postgres-vector")]
/// Defines an enum for the PostgreSQL module
#[derive(Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct VectorDef {
pub length: Option<u32>,
}

impl Type {
pub fn has_numeric_attr(&self) -> bool {
matches!(self, Type::Numeric(_) | Type::Decimal(_))
Expand Down Expand Up @@ -302,4 +314,9 @@ impl Type {
pub fn has_array_attr(&self) -> bool {
matches!(self, Type::Array(_))
}

#[cfg(feature = "postgres-vector")]
pub fn has_vector_attr(&self) -> bool {
matches!(self, Type::Vector(_))
}
}
25 changes: 25 additions & 0 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) ->
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_deref(), ctype, enums);
}
#[cfg(feature = "postgres-vector")]
if ctype.has_vector_attr() {
ctype = parse_vector_attributes(result.character_maximum_length, ctype);
}

ctype
}
Expand Down Expand Up @@ -240,3 +244,24 @@ pub fn parse_array_attributes(

ctype
}

#[cfg(feature = "postgres-vector")]
pub fn parse_vector_attributes(
character_maximum_length: Option<i32>,
mut ctype: ColumnType,
) -> ColumnType {
match ctype {
Type::Vector(ref mut attr) => {
attr.length = match character_maximum_length {
None => None,
Some(num) => match u32::try_from(num) {
Ok(num) => Some(num),
Err(_) => None,
},
};
}
_ => panic!("parse_vector_attributes(_) received a type that does not have StringAttr"),
};

ctype
}
21 changes: 6 additions & 15 deletions src/postgres/writer/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,12 @@ impl ColumnInfo {
Type::TsTzRange => ColumnType::Custom(Alias::new("tstzrange").into_iden()),
Type::DateRange => ColumnType::Custom(Alias::new("daterange").into_iden()),
Type::PgLsn => ColumnType::Custom(Alias::new("pg_lsn").into_iden()),
Type::Unknown(s) => {
#[cfg(feature = "postgres-vector")]
if s.starts_with("vector") {
let s = &s[6..];
return if s.starts_with("(") && s.ends_with(")") {
let s = &s[1..s.len() - 1];
let size = s.parse::<u32>().expect("Invalid vector size");
ColumnType::Vector(Some(size))
} else {
ColumnType::Vector(None)
}
};

ColumnType::Custom(Alias::new(s).into_iden())
}
#[cfg(feature = "postgres-vector")]
Type::Vector(vector_attr) => match vector_attr.length {
Some(length) => ColumnType::Vector(Some(length)),
None => ColumnType::Vector(None),
},
Type::Unknown(s) => ColumnType::Custom(Alias::new(s).into_iden()),
Type::Enum(enum_def) => {
let name = Alias::new(&enum_def.typename).into_iden();
let variants: Vec<DynIden> = enum_def
Expand Down

0 comments on commit e375754

Please sign in to comment.