Translator Lab Solution
Look here for the translator lab solution.
import json
from typing import Dict
import ray
from ray import serve
from starlette.requests import Request
from transformers import pipeline
@serve.deployment
class Translate:
def __init__(self, model: str):
self._model = model
self._pipeline = None
def get_response(self, user_input: str) -> str:
if (self._pipeline is None):
self._pipeline = pipeline(model=self._model)
outputs = self._pipeline(user_input)
return outputs
translate = Translate.bind(model='google/flan-t5-large')
@serve.deployment(ray_actor_options={"runtime_env" : {"pip": ["lingua-language-detector==1.3.2"]}})
class LangDetect:
def __init__(self):
self._detector = None
def get_response(self, user_input: str) -> str:
from lingua import Language, LanguageDetectorBuilder
if (self._detector is None):
languages = [Language.ENGLISH, Language.ITALIAN]
self._detector = LanguageDetectorBuilder.from_languages(*languages).build()
output = self._detector.detect_language_of(user_input)
if (output == Language.ENGLISH):
return 'en'
elif (output == Language.ITALIAN):
return 'it'
else:
raise Exception('Unsupported language')
lang_detect = LangDetect.bind()
@serve.deployment
class Endpoint:
def __init__(self, lang_detect, translate):
self._lang_detect = lang_detect
self._translate = translate
async def __call__(self, request: Request) -> Dict:
data = await request.json()
data = json.loads(data)
return {'response': await self.get_response(data['user_input']) }
async def get_response(self, user_input: str):
lang_obj_ref = await self._lang_detect.get_response.remote(user_input)
lang = await lang_obj_ref
if (lang == 'it'):
prompt = "Translate to English: "
elif (lang == 'en'):
prompt = "Translate to Italian: "
else:
raise Exception('Unsupported language')
result = await self._translate.get_response.remote(prompt + user_input)
response = await result
return response
endpoint = Endpoint.bind(lang_detect, translate)
endpoint_handle = serve.run(endpoint, name = 'translator')
r = endpoint_handle.get_response.remote("I like playing tennis.")
ray.get(r)
ray.get(endpoint_handle.get_response.remote("Mi piace giocare a tennis"))
serve.shutdown()