Skip to content

Commit

Permalink
New embeddings service (#526)
Browse files Browse the repository at this point in the history
* Base implementation of the word embedding search service

* use word match

* updated code after testing

* updated old class to convert words to new format
  • Loading branch information
theorm authored Feb 25, 2025
1 parent 464bb26 commit 3a98559
Show file tree
Hide file tree
Showing 16 changed files with 482 additions and 79 deletions.
8 changes: 8 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
"@types/node": "^22.5.5",
"@types/node-fetch": "^2.5.6",
"@types/sinon": "^17.0.3",
"@types/uuid": "^10.0.0",
"@types/wikidata-sdk": "^5.15.0",
"eslint": "^8.18.0",
"eslint-config-standard": "^17.0.0",
Expand Down
19 changes: 19 additions & 0 deletions src/models/generated/schemas.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2400,6 +2400,25 @@ export interface WikidataEntityDetailsTODOAddPersonLocationSpecificFields {
}


/**
* Represents a word match result from word embeddings similarity search
*/
export interface WordMatch {
/**
* Unique identifier for the word
*/
id: string;
/**
* The language code of the word
*/
languageCode: string;
/**
* The word
*/
word: string;
}


/**
* A year (TODO)
*/
Expand Down
19 changes: 19 additions & 0 deletions src/models/generated/schemasPublic.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,22 @@ export interface WikidataLocation {
longitude?: number;
};
}


/**
* Represents a word match result from word embeddings similarity search
*/
export interface WordMatch {
/**
* Unique identifier for the word
*/
id: string;
/**
* The language code of the word
*/
languageCode: string;
/**
* The word
*/
word: string;
}
22 changes: 22 additions & 0 deletions src/schema/schemas/WordMatch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WordMatch",
"description": "Represents a word match result from word embeddings similarity search",
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique identifier for the word"
},
"languageCode": {
"type": "string",
"description": "The language code of the word"
},
"word": {
"type": "string",
"description": "The word"
}
},
"required": ["id", "languageCode", "word"],
"additionalProperties": false
}
22 changes: 22 additions & 0 deletions src/schema/schemasPublic/WordMatch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WordMatch",
"description": "Represents a word match result from word embeddings similarity search",
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique identifier for the word"
},
"languageCode": {
"type": "string",
"description": "The language code of the word"
},
"word": {
"type": "string",
"description": "The word"
}
},
"required": ["id", "languageCode", "word"],
"additionalProperties": false
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { escapeValue } from '../../util/solr/filterReducers'
const debug = require('debug')('impresso/services:embeddings')
const { NotFound } = require('@feathersjs/errors')
const { measureTime } = require('../../util/instruments')
const { asFindAll } = require('../../util/solr/adapters')
import { NotFound } from '@feathersjs/errors'
import { measureTime } from '../../util/instruments'
import { asFindAll } from '../../util/solr/adapters'
import debugModule from 'debug'

const debug = debugModule('impresso/services:embeddings')

class Service {
constructor({ app = null, name = '' }) {
Expand All @@ -15,22 +17,22 @@ class Service {
}

async find(params) {
const namespace = `embeddings_${params.query.language}`
const namespace = `embeddings_${params.query.language_code}`
// use en to get embedding vector for the queried word
//
// https:// solrdev.dhlab.epfl.ch/solr/impresso_embeddings_de/select?q=word_s:amour&fl=embedding_bv
debug('[find] with params', params.query)

const bvRequest = {
q: `word_s:(${escapeValue(params.query.q)})`,
q: `word_s:(${escapeValue(params.query.term)})`,
fl: 'embedding_bv',
namespace,
}
const bv = await measureTime(
() =>
asFindAll(this.solr, namespace, bvRequest).then(res => {
if (!res.response.docs.length) {
throw new NotFound(`word "${params.query.q}" not found in available embeddings`)
throw new NotFound(`word "${params.query.term}" not found in available embeddings`)
}
return res.response.docs[0].embedding_bv
}),
Expand Down Expand Up @@ -58,8 +60,8 @@ class Service {
limit: params.query.limit,
offset: params.query.offset,
info: {
q: params.query.q,
language: params.query.language,
q: params.query.term,
language: params.query.language_code,
},
}
}
Expand Down Expand Up @@ -92,8 +94,5 @@ class Service {
}
}

module.exports = function (options) {
return new Service(options)
}

module.exports.Service = Service
export const createService = options => new Service(options)
export { Service }
72 changes: 72 additions & 0 deletions src/services/embeddings/embeddings-v1.hooks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { HookContext } from '@feathersjs/feathers'
import { ImpressoApplication } from '../../types'
import { v4 } from 'uuid'
import { WordMatch } from '../../models/generated/schemas'

const { queryWithCommonParams, validate } = require('../../hooks/params')

export default {
before: {
all: [],
find: [
validate(
{
language_code: {
choices: ['fr', 'de', 'lb'],
},
term: {
required: true,
regex: /^[A-zÀ-ÿ'()\s]+$/,
max_length: 500,
transform: (d: string) =>
d
.replace(/[^A-zÀ-ÿ]/g, ' ')
.toLowerCase()
.split(/\s+/)
.sort((a: string, b: string) => a.length - b.length)
.pop(),
},
},
'GET'
),
queryWithCommonParams(),
],
get: [],
create: [],
update: [],
patch: [],
remove: [],
},

after: {
all: [],
find: [
(context: HookContext<ImpressoApplication>) => {
if (Array.isArray(context.result.data)) {
context.result.data = context.result.data.map((word: string) => {
return {
word,
id: v4(),
languageCode: context.params.query.language ?? 'fr',
} satisfies WordMatch
})
}
},
],
get: [],
create: [],
update: [],
patch: [],
remove: [],
},

error: {
all: [],
find: [],
get: [],
create: [],
update: [],
patch: [],
remove: [],
},
}
118 changes: 118 additions & 0 deletions src/services/embeddings/embeddings.class.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import type { ClientService, Params } from '@feathersjs/feathers'
import { SimpleSolrClient } from '../../internalServices/simpleSolr'
import { PublicFindResponse as FindResponse } from '../../models/common'
import { ImpressoApplication } from '../../types'
import { escapeValue } from '../../util/solr/filterReducers'
import { WordMatch } from '../../models/generated/schemasPublic'

export type ValidLanguageCodes = 'de' | 'fr' | 'lb'

type FindQuery = Pick<FindResponse<unknown>['pagination'], 'limit' | 'offset'> & {
term: string
/** filter baseline vectors search by language. */
language_code?: ValidLanguageCodes
top_k?: number
}

const EmbeddingProperty = 'fastText_emb_v100'

interface SolrEmbeddingsDoc {
word_s: string
[EmbeddingProperty]: number[]
lg_s: string
id: string
}

const asWordMatch = (doc: Omit<SolrEmbeddingsDoc, typeof EmbeddingProperty>): WordMatch => ({
id: doc.id,
word: doc.word_s,
languageCode: doc.lg_s,
})

export const buildGetTermEmbeddingVectorSolrQuery = (term: string, language?: string): string => {
const parts = [`word_s:(${escapeValue(term)})`, language ? `lg_s:${language}` : undefined]
return parts.filter(p => p != null).join(' AND ')
}

export const buildFindBySimilarEmbeddingsSolrQuery = (vectors: number[][], topK: number): string => {
return vectors.map(vector => `({!knn f=${EmbeddingProperty} topK=${topK}}${JSON.stringify(vector)})`).join(' OR ')
}

export const DefaultPageSize = 20
export const DefaultTopK = 20

export class EmbeddingsService
implements Pick<ClientService<WordMatch, unknown, unknown, FindResponse<WordMatch>>, 'find'>
{
private readonly app: ImpressoApplication

constructor({ app }: { app: ImpressoApplication }) {
this.app = app
}

private get solr(): SimpleSolrClient {
return this.app.service('simpleSolrClient')
}

private async getTermEmbeddingVectors(term: string, language?: string): Promise<number[][]> {
const result = await this.solr.select<Pick<SolrEmbeddingsDoc, typeof EmbeddingProperty>>(
this.solr.namespaces.WordEmbeddings,
{
body: {
query: buildGetTermEmbeddingVectorSolrQuery(term, language),
fields: EmbeddingProperty,
limit: 1,
offset: 0,
},
}
)
return result?.response?.docs?.map(item => item[EmbeddingProperty]) ?? []
}

private async getWordsMatchingVectors(
vectors: number[][],
topK: number,
offset: number,
limit: number
): Promise<Omit<SolrEmbeddingsDoc, typeof EmbeddingProperty>[]> {
if (vectors.length === 0) return []
const result = await this.solr.select<Omit<SolrEmbeddingsDoc, typeof EmbeddingProperty>>(
this.solr.namespaces.WordEmbeddings,
{
body: {
query: buildFindBySimilarEmbeddingsSolrQuery(vectors, topK),
fields: ['word_s', 'lg_s', 'id'].join(','),
limit,
offset,
},
}
)
return result?.response?.docs ?? []
}

async find(params?: Params<FindQuery>): Promise<FindResponse<WordMatch>> {
if (!params?.query) {
throw new Error('Query parameters are required')
}

const {
term,
language_code: languageCode,
top_k: topK = DefaultTopK,
limit = DefaultPageSize,
offset = 0,
} = params.query

const vectors = await this.getTermEmbeddingVectors(term, languageCode)
const matches = await this.getWordsMatchingVectors(vectors, topK, offset, limit)

return {
pagination: {
limit,
offset,
total: matches.length,
},
data: matches.map(asWordMatch),
}
}
}
Loading

0 comments on commit 3a98559

Please sign in to comment.