update vector search to use 'annoy'
This commit is contained in:
parent
ad6a4366cf
commit
787c3161e3
99
content.py
99
content.py
|
|
@ -1044,38 +1044,6 @@ def txt_clean_index():
|
|||
|
||||
|
||||
|
||||
def search_embeddings():
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
save_embeds = pickle.load( open( "cache/embeddings.p", "rb" ) )
|
||||
columns = list(zip(*save_embeds))
|
||||
files = columns[0]
|
||||
sentences = columns[1]
|
||||
embeddings = columns[2]
|
||||
|
||||
print(files[:20])
|
||||
print(sentences[:20])
|
||||
print(embeddings[:20])
|
||||
|
||||
s = ''
|
||||
while s != 'q':
|
||||
s = input("search or 'q' to quit: ")
|
||||
if s == 'q':
|
||||
return
|
||||
query_embedding = model.encode(s)
|
||||
|
||||
# Compute the cosine similarity between the query embedding and the sentence embeddings
|
||||
cosine_scores = util.cos_sim(query_embedding, embeddings)
|
||||
|
||||
# Sort the sentences by their cosine similarity to the query sentence
|
||||
results = sorted(zip(sentences, cosine_scores, files), key=lambda x: x[1], reverse=True)
|
||||
|
||||
print(results[:5])
|
||||
|
||||
# Print the top 5 results
|
||||
for i, (sentence, score, file) in enumerate(results[:5]):
|
||||
print(f'Top {i+1}: {file} - {sentence} - (Score: {score})')
|
||||
|
||||
|
||||
|
||||
from whoosh import fields, columns
|
||||
from whoosh.index import create_in, open_dir
|
||||
|
|
@ -1207,27 +1175,77 @@ def create_search_index():
|
|||
writer.commit()
|
||||
|
||||
|
||||
|
||||
from annoy import AnnoyIndex
|
||||
import random
|
||||
|
||||
def test_embed():
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
sample = "What is this world coming to? What happens in the data and the research?"
|
||||
embed = model.encode(sample)
|
||||
|
||||
print("\nSample sentence:", sample)
|
||||
print("\nEmbedding:", embed)
|
||||
print("\nEmbedding size:", len(embed))
|
||||
|
||||
|
||||
def create_embeddings():
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
vecsize = 384 # sentence transformer embedding size
|
||||
t = AnnoyIndex(vecsize, 'angular')
|
||||
files = os.listdir('cache/crawl')
|
||||
output = []
|
||||
save_embeds = [] # ['file','sentence','embedding']
|
||||
output = [] # ['index', 'file','sentence']
|
||||
index = 0
|
||||
save_embeds = []
|
||||
files.sort()
|
||||
for f in files:
|
||||
print(f)
|
||||
m = re.match(r'https?..www\.gavilan\.edu\+(.*)\.\w\w\w\w?\.txt$',f)
|
||||
if m:
|
||||
lines = displayfile(f,1)
|
||||
embeddings = model.encode(lines)
|
||||
|
||||
print("\n-----", f)
|
||||
print("\n-----", index, f)
|
||||
|
||||
#Print the embeddings
|
||||
for sentence, embedding in zip(lines, embeddings):
|
||||
print("Sentence:", sentence)
|
||||
#print("Embedding:", embedding)
|
||||
if len(sentence.split(' ')) > 5:
|
||||
print(index, "Sentence:", sentence)
|
||||
print(embedding[:8])
|
||||
t.add_item(index, embedding)
|
||||
output.append( [index,f,sentence] )
|
||||
index += 1
|
||||
if index > 500:
|
||||
break
|
||||
t.build(30) # 30 trees
|
||||
t.save('cache/sentences.ann')
|
||||
pickle.dump( output, open( "cache/embedding_index.p", "wb" ) )
|
||||
|
||||
save_embeds.append([f,sentence,embedding])
|
||||
pickle.dump( save_embeds, open( "cache/embeddings.p", "wb" ) )
|
||||
|
||||
|
||||
|
||||
def search_embeddings():
|
||||
f = 384 # sentence transformer embedding size
|
||||
n = 10 # how many results
|
||||
|
||||
u = AnnoyIndex(f, 'angular')
|
||||
u.load('cache/sentences.ann') # super fast, will just mmap the file
|
||||
print(u.get_n_items(), "items in index")
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
search_index = pickle.load( open( "cache/embedding_index.p", "rb" ) )
|
||||
print(search_index)
|
||||
|
||||
|
||||
s = ''
|
||||
while s != 'q':
|
||||
s = input("search or 'q' to quit: ")
|
||||
if s == 'q':
|
||||
return
|
||||
query_embedding = model.encode(s)
|
||||
results = u.get_nns_by_vector(query_embedding, n)
|
||||
|
||||
# Print the top 5 results
|
||||
for i, r in enumerate(results):
|
||||
print(f'Top {i+1}: {r}, {search_index[r]}') #{file} - {sentence} - (Score: {score})')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -1248,6 +1266,7 @@ if __name__ == "__main__":
|
|||
13: ['do an index search', search_index],
|
||||
14: ['do a vector search', search_embeddings],
|
||||
15: ['test priority', test_priority],
|
||||
16: ['test embed', test_embed],
|
||||
}
|
||||
|
||||
if len(sys.argv) > 1 and re.search(r'^\d+',sys.argv[1]):
|
||||
|
|
|
|||
Loading…
Reference in New Issue