update vector search to use 'annoy'

Coding with Peter 2023-04-27 08:10:18 -07:00
parent ad6a4366cf
commit 787c3161e3
1 changed file with 60 additions and 41 deletions


@@ -1044,38 +1044,6 @@ def txt_clean_index():
-def search_embeddings():
-    model = SentenceTransformer('all-MiniLM-L6-v2')
-    save_embeds = pickle.load( open( "cache/embeddings.p", "rb" ) )
-    columns = list(zip(*save_embeds))
-    files = columns[0]
-    sentences = columns[1]
-    embeddings = columns[2]
-    print(files[:20])
-    print(sentences[:20])
-    print(embeddings[:20])
-    s = ''
-    while s != 'q':
-        s = input("search or 'q' to quit: ")
-        if s == 'q':
-            return
-        query_embedding = model.encode(s)
-        # Compute the cosine similarity between the query embedding and the sentence embeddings
-        cosine_scores = util.cos_sim(query_embedding, embeddings)
-        # Sort the sentences by their cosine similarity to the query sentence
-        results = sorted(zip(sentences, cosine_scores, files), key=lambda x: x[1], reverse=True)
-        print(results[:5])
-        # Print the top 5 results
-        for i, (sentence, score, file) in enumerate(results[:5]):
-            print(f'Top {i+1}: {file} - {sentence} - (Score: {score})')
 from whoosh import fields, columns
 from whoosh.index import create_in, open_dir
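
The function deleted above scored the query against every stored embedding with util.cos_sim, a full scan of the corpus per search. The Annoy index that replaces it answers nearest-neighbor queries from a prebuilt forest of trees, but it reports angular distance rather than cosine similarity. For unit-length vectors the two are interchangeable: Annoy defines angular distance as sqrt(2 * (1 - cos(u, v))), so a cosine score can still be recovered. A minimal sketch (the helper name is not from this commit):

# Sketch: map Annoy's 'angular' distance back to a cosine score.
# Annoy documents angular distance as sqrt(2 * (1 - cos(u, v))).
def angular_to_cosine(d: float) -> float:
    return 1.0 - (d * d) / 2.0

assert abs(angular_to_cosine(0.0) - 1.0) < 1e-9       # identical vectors
assert abs(angular_to_cosine(2 ** 0.5) - 0.0) < 1e-9  # orthogonal vectors
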
@@ -1207,27 +1175,77 @@ def create_search_index():
     writer.commit()
+from annoy import AnnoyIndex
+import random
+
+def test_embed():
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    sample = "What is this world coming to? What happens in the data and the research?"
+    embed = model.encode(sample)
+    print("\nSample sentence:", sample)
+    print("\nEmbedding:", embed)
+    print("\nEmbedding size:", len(embed))
 def create_embeddings():
     model = SentenceTransformer('all-MiniLM-L6-v2')
+    vecsize = 384 # sentence transformer embedding size
+    t = AnnoyIndex(vecsize, 'angular')
     files = os.listdir('cache/crawl')
-    output = []
-    save_embeds = [] # ['file','sentence','embedding']
+    output = [] # ['index', 'file','sentence']
+    index = 0
+    save_embeds = []
     files.sort()
     for f in files:
+        print(f)
         m = re.match(r'https?..www\.gavilan\.edu\+(.*)\.\w\w\w\w?\.txt$',f)
         if m:
             lines = displayfile(f,1)
             embeddings = model.encode(lines)
-            print("\n-----", f)
-            #Print the embeddings
+            print("\n-----", index, f)
             for sentence, embedding in zip(lines, embeddings):
-                print("Sentence:", sentence)
-                #print("Embedding:", embedding)
-                save_embeds.append([f,sentence,embedding])
+                if len(sentence.split(' ')) > 5:
+                    print(index, "Sentence:", sentence)
+                    print(embedding[:8])
+                    t.add_item(index, embedding)
+                    output.append( [index,f,sentence] )
+                    index += 1
+        if index > 500:
+            break
-    pickle.dump( save_embeds, open( "cache/embeddings.p", "wb" ) )
+    t.build(30) # 30 trees
+    t.save('cache/sentences.ann')
+    pickle.dump( output, open( "cache/embedding_index.p", "wb" ) )
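
Two properties of the rewritten create_embeddings are worth flagging: an Annoy index is immutable once build() has run, so new crawl files require rebuilding from scratch, and the index > 500 guard caps the index at 500 sentences, presumably for testing. A quick post-build sanity check might look like this (a sketch, assuming the cache paths written above):

# Sketch: reload the saved index and confirm it lines up with the pickle.
from annoy import AnnoyIndex
import pickle

t = AnnoyIndex(384, 'angular')
t.load('cache/sentences.ann')
rows = pickle.load(open('cache/embedding_index.p', 'rb'))
assert t.get_n_items() == len(rows)    # one vector per [index, file, sentence] row
assert t.get_nns_by_item(0, 1) == [0]  # an item's nearest neighbor is itself
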
+def search_embeddings():
+    f = 384 # sentence transformer embedding size
+    n = 10 # how many results
+    u = AnnoyIndex(f, 'angular')
+    u.load('cache/sentences.ann') # super fast, will just mmap the file
+    print(u.get_n_items(), "items in index")
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    search_index = pickle.load( open( "cache/embedding_index.p", "rb" ) )
+    print(search_index)
+    s = ''
+    while s != 'q':
+        s = input("search or 'q' to quit: ")
+        if s == 'q':
+            return
+        query_embedding = model.encode(s)
+        results = u.get_nns_by_vector(query_embedding, n)
+        # Print the top n results
+        for i, r in enumerate(results):
+            print(f'Top {i+1}: {r}, {search_index[r]}')
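
The new loop prints raw item ids plus rows from the pickled lookup table, but loses the score column the cosine version reported. get_nns_by_vector accepts include_distances=True, which restores it; a sketch of the same lookup with distances (the function name is not from this commit):

# Sketch: the same search, also printing each hit's angular distance.
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
import pickle

def search_with_scores(query, n=10):
    u = AnnoyIndex(384, 'angular')
    u.load('cache/sentences.ann')
    model = SentenceTransformer('all-MiniLM-L6-v2')
    search_index = pickle.load(open('cache/embedding_index.p', 'rb'))
    ids, dists = u.get_nns_by_vector(model.encode(query), n, include_distances=True)
    for rank, (i, d) in enumerate(zip(ids, dists), start=1):
        _, fname, sentence = search_index[i]  # rows are [index, file, sentence]
        print(f'Top {rank}: {fname} - {sentence} (angular distance: {d:.3f})')
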
 if __name__ == "__main__":
@@ -1248,6 +1266,7 @@ if __name__ == "__main__":
         13: ['do an index search', search_index],
         14: ['do a vector search', search_embeddings],
         15: ['test priority', test_priority],
+        16: ['test embed', test_embed],
     }
     if len(sys.argv) > 1 and re.search(r'^\d+',sys.argv[1]):