atextcrawler/src/atextcrawler/tensorflow.py

70 lines
2.1 KiB
Python

"""
Query the tensorflow_model_server's REST API.
"""
import logging
from typing import Optional, Union
import aiohttp
# Module-level logger, named after this module; used for all request failures.
logger = logging.getLogger(__name__)
class TensorFlow:
    """
    Fetch an embedding vector from the tensorflow model server.

    The server endpoint URL is read from ``tf_config['model_server_endpoint']``
    on every request. The aiohttp session is supplied by the caller and is
    never closed by this class.
    """

    def __init__(
        self,
        tf_config,
        session: aiohttp.ClientSession,
        timeout_sock_connect: Union[int, float] = 0.5,
        timeout_sock_read: Union[int, float] = 10,
    ):
        """
        :param tf_config: mapping that provides 'model_server_endpoint'
        :param session: HTTP session used for all requests
        :param timeout_sock_connect: seconds allowed for connecting a socket
        :param timeout_sock_read: seconds allowed for reading a portion of
            the response from the server
        """
        self.config = tf_config
        self.session = session
        # Only socket-level limits are given; no overall 'total' timeout
        # is configured here.
        self.timeout = aiohttp.ClientTimeout(
            sock_connect=timeout_sock_connect, sock_read=timeout_sock_read
        )

    async def embed(
        self, text: Union[str, list[str]]
    ) -> Optional[Union[list[float], list[list[float]]]]:
        """
        Query the tensorflow_model_server's REST API for a prediction.

        Take a string or a list of strings and return an embedding vector
        or a list of embedding vectors.

        If the request fails or times out, or the server's answer does not
        contain usable predictions, log the problem and return None.
        """
        # A non-list argument is sent as a batch of one and unwrapped again.
        single = not isinstance(text, list)
        instances = [text] if single else text
        data = {'signature_name': 'serving_default', 'instances': instances}
        try:
            async with self.session.post(
                self.config['model_server_endpoint'],
                json=data,
                timeout=self.timeout,
            ) as resp:
                try:
                    res = await resp.json()
                except (aiohttp.ContentTypeError, ValueError) as err:
                    # Body is not JSON (wrong content type or malformed).
                    logger.error(
                        f'Got invalid response from tensorflow, reason: {err}'
                    )
                    return None
        except Exception as err:
            # Connection errors, timeouts, payload errors, missing config key.
            # asyncio.CancelledError is not an Exception and still propagates.
            logger.error(
                'Could not get embedding from tensorflow for '
                f'{self._describe_input(text)}, reason: {err}'
            )
            return None
        return self._extract_predictions(res, single)

    @staticmethod
    def _describe_input(text) -> str:
        """Summarize the request input (lengths only) for log messages."""
        if isinstance(text, str):
            return f'string of length {len(text)}'
        lengths = ','.join(str(len(s)) for s in text)
        return f'list of strings with lengths {lengths}'

    @staticmethod
    def _extract_predictions(res, single: bool):
        """
        Return the embedding(s) from a decoded response body, or None.

        For a single input the first prediction is returned, otherwise the
        whole list of predictions. Anything else is logged as invalid.
        """
        predictions = res.get('predictions') if isinstance(res, dict) else None
        if isinstance(predictions, list) and (predictions or not single):
            return predictions[0] if single else predictions
        msg = 'Got invalid response from tensorflow'
        # NOTE(review): TF Serving reports request failures under an 'error'
        # key — confirm against the deployed server version.
        if isinstance(res, dict) and res.get('error'):
            msg += f', reason: {res["error"]}'
        logger.error(msg)
        return None