# atextcrawler/src/atextcrawler/utils/similarity.py

"""
Text similarity with simhash.
"""
import logging
from asyncpg import Connection
from simhash import Simhash, SimhashIndex
logger = logging.getLogger(__name__)
# Silence anything below ERROR; this logger is also handed to the simhash
# library via its ``log`` keyword argument, so this quiets simhash as well.
logger.setLevel(logging.ERROR)
# 2**63: Simhash values are unsigned 64-bit ints, PostgreSQL's bigint is a
# signed 64-bit int; shifting by this offset maps one range onto the other.
postgresql_bigint_offset = 9223372036854775808
"""
Subtract this number to get a PostgreSQL bigint from a 64bit int.
"""
def get_features(txt: str) -> list[str]:
    """
    Return overlapping character 3-grams of *txt* for use with Simhash.

    Spaces are removed and the text is lowercased first. For texts shorter
    than the n-gram width a single (possibly empty) feature is returned.
    """
    width = 3
    normalized = txt.replace(' ', '').lower()
    n_features = max(len(normalized) - width + 1, 1)
    return [normalized[pos : pos + width] for pos in range(n_features)]
def simhash_to_bigint(simhash: Simhash) -> int:
    """
    Shift a Simhash's unsigned 64-bit value down into PostgreSQL's
    signed bigint range.
    """
    shifted = simhash.value - postgresql_bigint_offset
    return shifted
def simhash_from_bigint(bigint: int) -> Simhash:
    """
    Rebuild a Simhash instance from a value stored as a PostgreSQL bigint.

    Inverse of :func:`simhash_to_bigint`.
    """
    unsigned_value = bigint + postgresql_bigint_offset
    return Simhash(unsigned_value, log=logger)
def get_simhash(text: str) -> Simhash:
    """
    Compute the Simhash of the given text from its character n-gram features.
    """
    features = get_features(text)
    return Simhash(features, log=logger)
async def get_simhash_index(conn: Connection, site_id: int) -> SimhashIndex:
    """
    Return a simhash index with hashes of all stored resources of the site.

    :param conn: open asyncpg connection
    :param site_id: id of the site whose resources are indexed
    """
    sql = (
        "SELECT r.id, r.simhash FROM site_path sp, resource r"
        " WHERE sp.site_id=$1 AND sp.resource_id=r.id"
    )
    rows = await conn.fetch(sql, site_id)
    # Delegate the bigint -> Simhash conversion to simhash_from_bigint so
    # the offset arithmetic lives in exactly one place.
    objs = [
        (str(row['id']), simhash_from_bigint(row['simhash'])) for row in rows
    ]
    # k=3: resources whose simhashes differ in at most 3 bits count as near
    # duplicates (same tolerance used by search_simhash via this index).
    return SimhashIndex(objs, k=3, log=logger)
def create_simhash(
    index: SimhashIndex,
    resource_id: int,
    simhash_instance: Simhash,
) -> int:
    """
    Add a resource with given id and simhash to a simhash index.
    Return the simhash value shifted into PostgreSQL's bigint range.
    (The simhash field of the resource's database entry is not updated.)
    """
    # SimhashIndex keys are strings; resource ids are stored stringified.
    key = str(resource_id)
    index.add(key, simhash_instance)
    return simhash_to_bigint(simhash_instance)
def search_simhash(index: SimhashIndex, simhash_inst: Simhash) -> list[int]:
    """
    Return the ids of resources in *index* that are near-duplicates of
    *simhash_inst*, sorted ascending; an empty list if there are none.
    """
    matches = index.get_near_dups(simhash_inst)
    if not matches:
        return []
    # Index keys are stringified resource ids; convert back to ints.
    return sorted(int(match) for match in matches)