"""
Query the tensorflow_model_server's REST API.
"""

import logging
from typing import Optional, Union

import aiohttp

logger = logging.getLogger(__name__)


class TensorFlow:
    """
    Fetch an embedding vector from the tensorflow model server.

    Requests go through a caller-owned ``aiohttp.ClientSession``; this class
    never opens or closes that session itself.
    """

    def __init__(
        self,
        tf_config,
        session: aiohttp.ClientSession,
        timeout_sock_connect: Union[int, float] = 0.5,
        timeout_sock_read: Union[int, float] = 10,
    ):
        """
        :param tf_config: mapping that must provide the
            ``'model_server_endpoint'`` URL the prediction request is POSTed to.
        :param session: shared HTTP session used for every request.
        :param timeout_sock_connect: socket connect timeout in seconds.
        :param timeout_sock_read: socket read timeout in seconds.
        """
        self.config = tf_config
        self.session = session
        self.timeout = aiohttp.ClientTimeout(
            sock_connect=timeout_sock_connect, sock_read=timeout_sock_read
        )

    async def embed(
        self, text: Union[str, list[str]]
    ) -> Optional[Union[list[float], list[list[float]]]]:
        """
        Query the tensorflow_model_server's REST API for a prediction.

        Take a string or a list of strings and return an embedding vector or a
        list of embedding vectors. If the request fails, times out, or the
        server answers with something other than a JSON object holding
        ``predictions``, log the problem and return None.

        Cancellation of the calling task is propagated, not swallowed: while
        parsing, only the errors a malformed response can produce are caught.
        """
        batched = isinstance(text, list)
        # The request body always carries a list of instances, so a single
        # string is wrapped here and its one prediction unwrapped afterwards.
        instances = text if batched else [text]
        data = {'signature_name': 'serving_default', 'instances': instances}
        try:
            async with self.session.post(
                self.config['model_server_endpoint'],
                json=data,
                timeout=self.timeout,
            ) as resp:
                res = None
                try:
                    res = await resp.json()
                    predictions = res['predictions']
                    return predictions if batched else predictions[0]
                except (
                    aiohttp.ContentTypeError,  # body not served as JSON
                    ValueError,  # body is not valid JSON
                    KeyError,  # JSON object without 'predictions'
                    TypeError,  # JSON not an object, or predictions is null
                    IndexError,  # empty predictions for a single string
                ) as err:
                    # Prefer the server's own error text when it sent one.
                    detail = res.get('error') if isinstance(res, dict) else None
                    logger.error(
                        'Got invalid response from tensorflow (HTTP %s): %s',
                        resp.status,
                        detail if detail is not None else repr(err),
                    )
                    return None
        except Exception as err:
            # repr() rather than str(): str() of a timeout error is empty,
            # which would leave the logged reason blank.
            logger.error(
                'Could not get embedding from tensorflow for %s, reason: %r',
                self._describe_input(text),
                err,
            )
            return None

    @staticmethod
    def _describe_input(text: Union[str, list[str]]) -> str:
        """Summarise the request input by length only, keeping text out of logs."""
        if isinstance(text, str):
            return f'string of length {len(text)}'
        return 'list of strings with lengths ' + ','.join(
            str(len(s)) for s in text
        )