Browse Source

feat: la ponctuation utilise un serveur websocket

master
glorf 3 months ago committed by Guillaume Dupuy
parent
commit
a7536cc973
  1. 1
      .dockerignore
  2. 8
      docker-compose.dev.yml
  3. 5
      docker-compose.yml
  4. 1
      punctuation-server/.dockerignore
  5. 12
      punctuation-server/Dockerfile
  6. 21
      punctuation-server/client.py
  7. 5
      punctuation-server/install_models.py
  8. 21
      punctuation-server/requirements.txt
  9. 25
      punctuation-server/sample.txt
  10. 57
      punctuation-server/server.py
  11. 3
      requirements.txt
  12. 8
      scribe/__init__.py
  13. 37
      scribe/lib.py
  14. 55
      scribe/punctuater.py
  15. 31
      scribe/transcribers.py
  16. 4
      scribe/types.py

1
.dockerignore

@ -0,0 +1 @@
venv

8
docker-compose.dev.yml

@ -9,6 +9,14 @@ services:
command: "flask run"
volumes:
- ./scribe/:/scribe/scribe/ # Bind local code folder for easier dev
dmp-server:
environment:
- DMP_SERVER_PORT=2800
build:
context: .
dockerfile: punctuation-server/Dockerfile
volumes:
- ./punctuation-server/:/punctuation-server/
vosk-en:
image: alpine

5
docker-compose.yml

@ -13,7 +13,12 @@ services:
depends_on:
- vosk-fr
- vosk-en
- dmp-server
vosk-fr:
image: alphacep/kaldi-fr:latest
vosk-en:
image: alphacep/kaldi-en:latest
dmp-server:
image: gloorf/dmp-server:latest
ports:
- 2800:2800

1
punctuation-server/.dockerignore

@ -0,0 +1 @@
venv

12
punctuation-server/Dockerfile

@ -0,0 +1,12 @@
FROM debian:11
# Install the interpreter and pip; remove the apt package lists in the same
# layer so they are not baked into the image.
RUN apt-get update \
    && apt-get install -y python3 python3-pip \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /punctuation-server/
# Copy only the files the dependency and model layers need, so that editing
# server.py does not invalidate them (and re-download everything) on each build.
COPY punctuation-server/requirements.txt punctuation-server/install_models.py /punctuation-server/
# --no-cache-dir keeps pip's download cache out of the image
RUN pip install --no-cache-dir -r requirements.txt
# the model is downloaded at runtime, so let's make sure it's downloaded now so it's included in the image
RUN python3 install_models.py
# Now bring in the rest of the server code
COPY punctuation-server/ /punctuation-server/
EXPOSE 2800
CMD ["python3", "/punctuation-server/server.py"]

21
punctuation-server/client.py

@ -0,0 +1,21 @@
import websockets
import asyncio
import json
import sys
async def send(uri, path="sample.txt"):
    """
    Send the content of a text file to the punctuation server and print the answer.

    :param uri: websocket URI of the punctuation server (e.g. "ws://localhost:2800")
    :param path: text file to send; defaults to the bundled "sample.txt"
    """
    # The sample contains accented characters: read it explicitly as UTF-8
    # instead of relying on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        message = f.read()
    print(len(message))
    async with websockets.connect(uri) as websocket:
        # The server expects a JSON object with a "text" key
        await websocket.send(json.dumps({"text": message}))
        answer = await websocket.recv()
    print(json.loads(answer))
if __name__ == "__main__":
    # The server URI may be given as first CLI argument; otherwise use the local default
    uri = sys.argv[1] if len(sys.argv) > 1 else "ws://localhost:2800"
    asyncio.run(send(uri))

5
punctuation-server/install_models.py

@ -0,0 +1,5 @@
from deepmultilingualpunctuation import PunctuationModel
# Instantiating the model forces pytorch/transformers (not sure which one is
# responsible for this) to download it right now, at image build time,
# so it's included in our docker image.
model = PunctuationModel()

21
punctuation-server/requirements.txt

@ -0,0 +1,21 @@
certifi==2022.6.15
charset-normalizer==2.0.12
deepmultilingualpunctuation==1.0.1
filelock==3.7.1
huggingface-hub==0.8.1
idna==3.3
numpy==1.23.0
packaging==21.3
protobuf==3.20.1
pyparsing==3.0.9
PyYAML==6.0
regex==2022.6.2
requests==2.28.0
sentencepiece==0.1.96
tokenizers==0.12.1
torch==1.11.0
tqdm==4.64.0
transformers[torch]==4.20.1
typing_extensions==4.2.0
urllib3==1.26.9
websockets==10.3

25
punctuation-server/sample.txt

@ -0,0 +1,25 @@
ça y est
la gauche s'unit enfin de son côté emmanuel macron avait dit que cette victoire l'obligeait à gauche aussi ont tirer les leçons mère
c'est plutôt la défaite qui les oblige
bon
bon
en discutant ils se sont très vite rendu compte qu'ils faisaient tous partie de la même famille politique comme dans toutes les familles il y a des sujets tabous avant tout vous voulez qu'on commence par envoyer un message clair à vos électeurs sur votre positionnement vis-à-vis de l'europe
déjà
puisse
du donbass à faut-il accentuer l'aide
macron il est pas un peu
ça ça va casser
sûr que je ne suis pas tombé sur le dossier le plus simple c'est très fastidieux très compliqué ça change de ma dernière négo quand zemmour avait tenté de faire l'union de l'extrême droite là le job idiot
trois minutes le dîner est arrivé
alors conformément aux demandes nous avons donc des pizzas ananas tofu anchois agneau raclette une de chaque mais non tout mélanger
ah bah oui quand personne n'a fait des concessions ça fait forcément résultat dégueulasse que ça vous serve de leçon pour ce soir
un appétit il est vrai qu'on signe souvent les accords tard dans la nuit parce que il y a de nombreux points de blocage et aussi parce que tant que mélenchon s'endorme donc pour le nom on est d'accord qu'on partout sur nouvelle union populaire écologique et sociale et démocrate européenne mais souverainiste maintenant non non non non non non
d'accord c'est déjà pas mal
toutes façons faites gaffe en plus le nom est long plus les gens vont voir vous n'êtes pas d'accord alors il y a des parties simples à convaincre et d'autres avec qui c'est plus compliqué
p s pour qui rejoindre une union écolo passe forcément par un gros massacre d'éléphants je n'ai pas tout le monde serrez-vous un peu sur les côtés
plus à droite on va pas recommencer ok vous vous prenez pas la tête la façon à la fin on va mettre un grand portrait de mélenchon devant
en tout cas on assiste à un événement historique il y a encore un mois qui pensait que europe écologie les verts s'associera avec un parti qui est prêt à remettre en cause les règles européennes qu'un leader communiste pro-nucléaire reviendrait sur sa position en tout cas on est tous réunis aujourd'hui autour d'une même conviction c'est que pour s'approcher du pouvoir
pour accepter sa force de conviction
en route
broute haute

57
punctuation-server/server.py

@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Inspired by https://github.com/alphacep/vosk-server/blob/master/websocket/asr_server.py
import json
import os
import asyncio
from dataclasses import dataclass
import websockets
import logging
from deepmultilingualpunctuation import PunctuationModel
@dataclass
class PunctuationServerArgs:
    """Network settings used to start the punctuation websocket server."""
    # Address of the interface the server binds to (e.g. '0.0.0.0')
    interface: str
    # Port the server listens on
    port: int
async def punctuate(websocket):
    """
    Websocket handler: answer each incoming message with punctuation predictions.

    Each message must be a JSON object with a string "text" key. The answer is a JSON
    object {"predictions": [[word, label, score], ...]}.
    Messages that cannot be decoded, or that do not have a valid "text" field, are logged
    and skipped (no answer is sent for them); the connection stays open.

    :param websocket: the connection to serve; relies on the module-level `model`
                      (set by start()) to compute the predictions
    """
    logging.debug(f"Connection from {websocket.remote_address}")
    async for message in websocket:
        try:
            data = json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logging.warning(f"Could not decode '{message}'")
            continue
        # Validate explicitly: an `assert` would be stripped with -O, and a failing one
        # (or a non-object payload) would abort the whole connection handler
        if not isinstance(data, dict) or not isinstance(data.get("text"), str):
            logging.warning(f"Missing or invalid 'text' field in '{message}'")
            continue
        clean_text = model.preprocess(data["text"])
        if clean_text:
            # DMP returns float32 instead of float, and float32 aren't JSON-serializable
            labeled_words = [[x[0], x[1], float(x[2])] for x in model.predict(clean_text)]
        else:
            # Nothing to label: skip the model call rather than feeding it an empty input
            labeled_words = []
        await websocket.send(json.dumps({"predictions": labeled_words}))
async def start():
    """
    Configure logging, load the punctuation model and serve websocket requests forever.

    The listening interface and port are read from the DMP_SERVER_INTERFACE and
    DMP_SERVER_PORT environment variables (defaults: '0.0.0.0' and 2800).
    """
    global model
    # To get the websockets library logs as well, enable its logger:
    #
    # logger = logging.getLogger('websockets')
    # logger.setLevel(logging.INFO)
    # logger.addHandler(logging.StreamHandler())
    logging.basicConfig(level=logging.INFO)

    interface = os.environ.get('DMP_SERVER_INTERFACE', '0.0.0.0')
    port = int(os.environ.get('DMP_SERVER_PORT', 2800))
    args = PunctuationServerArgs(interface=interface, port=port)

    # Loaded once and shared by every connection through the module-level global
    model = PunctuationModel()

    server = websockets.serve(punctuate, args.interface, args.port)
    async with server:
        # A future that is never resolved: keeps the server running until cancelled
        await asyncio.Future()
# Start the server only when this file is executed as a script
if __name__ == '__main__':
    asyncio.run(start())

3
requirements.txt

@ -24,6 +24,3 @@ Werkzeug==2.1.2
youtube-dl==2021.12.17
yt-dlp==2022.4.8
zipp==3.8.0
protobuf==3.20.1
sentencepiece==0.1.96
deepmultilingualpunctuation==1.0.0

8
scribe/__init__.py

@ -15,6 +15,7 @@ from .lib import get_file, get_url, get_mail, get_lang, get_translation_lang
from .utils import ALLOWED_EXTENSIONS, error_wrapper, get_ip, get_translation_langs
from .senders import MailSender, LocalSender, VirtualSender
from .translater import LibreTranslater
from .punctuater import DMPPunctuater
from .transcribers import FileTranscriber, YoutubeTranscriber, VirtualTranscriber
@ -115,13 +116,14 @@ def start() -> Union[werkzeug.wrappers.response.Response, str]: # noqa: C901
lang.lower(), translation_lang)
else:
translater = None
punctuater = DMPPunctuater("ws://dmp-server:2800")
# Now let's create our transcriber
if isinstance(file, FileStorage):
transcriber: VirtualTranscriber = FileTranscriber(
current_app.logger, sender, translater, vosk_uri, timings_path, logged_mail, ip, file)
transcriber: VirtualTranscriber = FileTranscriber(current_app.logger, sender, translater,
punctuater, vosk_uri, timings_path, logged_mail, ip, file)
display_name = file.filename
elif isinstance(url, str):
transcriber = YoutubeTranscriber(current_app.logger, sender, translater,
transcriber = YoutubeTranscriber(current_app.logger, sender, translater, punctuater,
vosk_uri, timings_path, logged_mail, ip, url,
current_app.config['FFMPEG_LOCATION'])
display_name = url

37
scribe/lib.py

@ -5,8 +5,6 @@ from pathlib import Path
from typing import Union, cast, Sequence
import contextlib
from deepmultilingualpunctuation import PunctuationModel
import flask.wrappers
import websockets.client
import validators # type: ignore
@ -99,7 +97,7 @@ async def send_to_vosk(uri: str, audio_path: Path) -> Sequence[Union[VoskPartial
return results
def filter_success(results: Sequence[Union[VoskPartial, VoskPhrase]]) -> Sequence[VoskPhrase]:
def filter_vosk_success(results: Sequence[Union[VoskPartial, VoskPhrase]]) -> Sequence[VoskPhrase]:
"""
Filter the partial result out.
We could do it directly in send_to_vosk, but we might need partial results, one day
@ -107,44 +105,11 @@ def filter_success(results: Sequence[Union[VoskPartial, VoskPhrase]]) -> Sequenc
return [cast(VoskPhrase, x) for x in results if "result" in x]
# load the punctuation model only once
_punctuationmodel = PunctuationModel()
def __punctuation_to_text(prediction):
    """
    Change a prediction result array from deepmultilingualpunctuation to a end-user text string

    :param prediction: iterable of (word, label, score) triples; label is "0" when no
                       punctuation follows the word, or the punctuation mark to append
    :return: the words joined into a single string, with punctuation inserted and the
             first word of each sentence capitalized
    """
    pieces = []
    capnext = True
    for word, label, _ in prediction:
        if capnext:
            # Only touch the first letter: str.capitalize() would also lowercase the rest
            word = word[:1].upper() + word[1:]
            capnext = False
        pieces.append(word)
        if label == "0":
            pieces.append(" ")
        elif label in ".,?-:":
            pieces.append(label + " ")
            # A full stop or a question mark ends the sentence
            if label in (".", "?"):
                capnext = True
    # NOTE: the former trailing `result.capitalize()` discarded its return value (strings
    # are immutable) and so had no effect; applying it would have undone the per-sentence
    # capitalization, hence it is simply dropped.
    return "".join(pieces).strip()
def to_phrases(results: Sequence[VoskPhrase], with_timing: bool = False) -> str:
"""
Return all phrases from audio
Optionally, add timings (in the form [HH'MM'SS])
"""
global _punctuationmodel
# Restore punctuation
for x in results:
clean_text = _punctuationmodel.preprocess(x["text"])
labeled_words = _punctuationmodel.predict(clean_text)
x["text"] = __punctuation_to_text(labeled_words)
# Very simple, one phrase = one row
if not with_timing:

55
scribe/punctuater.py

@ -0,0 +1,55 @@
from typing import Sequence
from dataclasses import dataclass
import websockets.client
import json
from .types import VoskPhrase, VoskWord
@dataclass
class DMPPunctuater:
    """
    Client for the deepmultilingualpunctuation (DMP) websocket server.

    Sends vosk phrases to the server and merges the returned predictions back into them.
    """
    # URI of the DMP websocket server (e.g. "ws://dmp-server:2800")
    uri: str

    # Labels that are appended to the word as-is (un-annotated: not dataclass fields)
    _PUNCTUATION_LABELS = frozenset(".,?-:")
    # Labels after which the next word starts a new sentence
    _SENTENCE_END_LABELS = frozenset(".?")

    def punctuate_phrase(self, phrase: "VoskPhrase", prediction: Sequence) -> "VoskPhrase":
        """
        Update a vosk phrase with a prediction result array from deepmultilingualpunctuation
        e.g punctuation is inserted in each result[i]['word'] ; 'Berkeley becomes 'Berkeley,' and so on

        :param phrase: the vosk phrase to punctuate; it is not modified
        :param prediction: one (word, label, score) triple per word of the phrase
        :return: a new phrase, with punctuated words (timings and confidence are kept)
                 and a "text" rebuilt from them
        :raises RuntimeError: if there is not exactly one prediction per word
        """
        if len(phrase["result"]) != len(prediction):
            raise RuntimeError("len(phrase['result']) != len(prediction)")
        words = []
        capnext = True
        for old_word, (word, label, _) in zip(phrase["result"], prediction):
            if capnext:
                # Only touch the first letter: str.capitalize() would also lowercase the rest
                word = word[:1].upper() + word[1:]
                capnext = False
            if label in self._PUNCTUATION_LABELS:
                word += label
                # Both a full stop and a question mark end the sentence
                capnext = label in self._SENTENCE_END_LABELS
            v_word: VoskWord = {"conf": old_word["conf"],
                                "start": old_word["start"],
                                "end": old_word["end"],
                                "word": word}
            words.append(v_word)
        final_text = " ".join(x["word"] for x in words)
        new_phrase: VoskPhrase = {"result": words, "text": final_text}
        return new_phrase

    async def punctuate(self, vosk_phrases: "Sequence[VoskPhrase]") -> "Sequence[VoskPhrase]":
        """
        Sends the phrases to deepmultilingual server to punctuate them
        Punctuation is inserted in each result[i]['word'] - see to_subtitles for more information about the structure

        :param vosk_phrases: the phrases to punctuate
        :return: the punctuated phrases, in the same order
        """
        out = []
        for phrase in vosk_phrases:
            if not phrase["result"]:
                # No word, nothing to punctuate: keep the phrase and save a round-trip
                out.append(phrase)
                continue
            # websocket are relatively cheap, so we can make one connection per phrase
            # It's not very optimised, but ... this project uses three different ML models, are we really picky
            # about a little bit of inefficient code ? :P
            async with websockets.client.connect(self.uri) as websocket:
                await websocket.send(json.dumps({"text": phrase["text"]}))
                dmp_result = await websocket.recv()
            predictions = json.loads(dmp_result)["predictions"]
            out.append(self.punctuate_phrase(phrase, predictions))
        return out

31
scribe/transcribers.py

@ -24,7 +24,8 @@ from .types import VoskUri, Timings, MailAddress, VoskPhrase
from .utils import create_zip
from .senders import VirtualSender
from .translater import LibreTranslater
from .lib import filter_success, to_subtitles, to_phrases, send_to_vosk
from .punctuater import DMPPunctuater
from .lib import filter_vosk_success, to_subtitles, to_phrases, send_to_vosk
class FileNameCollectorPP(yt_dlp.postprocessor.common.PostProcessor):
@ -45,6 +46,7 @@ class FileNameCollectorPP(yt_dlp.postprocessor.common.PostProcessor):
class VirtualTranscriber(metaclass=ABCMeta):
sender: VirtualSender
translater: Optional[LibreTranslater]
punctuater: Optional[DMPPunctuater]
vosk_uri: VoskUri
timings: Timings
timings_path: Path
@ -52,10 +54,12 @@ class VirtualTranscriber(metaclass=ABCMeta):
ip: str
def __init__(self, logger: logging.Logger, sender: VirtualSender, translater: Optional[LibreTranslater],
vosk_uri: VoskUri, timings_path: Path, mail: Optional[MailAddress], ip: str) -> None:
punctuater: Optional[DMPPunctuater], vosk_uri: VoskUri, timings_path: Path,
mail: Optional[MailAddress], ip: str) -> None:
self.logger = logger
self.sender = sender
self.translater = translater
self.punctuater = punctuater
self.vosk_uri = vosk_uri
self.timings = Timings()
self.timings_path = timings_path
@ -92,8 +96,16 @@ class VirtualTranscriber(metaclass=ABCMeta):
def transcribe_audio(self, audio: Path, vosk_uri: VoskUri) -> BytesIO:
# Let's send data to vosk
success = filter_success(asyncio.run(send_to_vosk(vosk_uri, audio)))
# Let's process resutls into something usable
success = filter_vosk_success(asyncio.run(send_to_vosk(vosk_uri, audio)))
if self.punctuater is not None:
# We don't want to fail transcription if only punctuation failed, let's log the error
# and move on
try:
success = asyncio.run(self.punctuater.punctuate(success))
except Exception:
self.logger.error("Punctuation error :")
self.logger.error(traceback.format_exc())
# Let's process results into something usable
archive = self.generate_archive(success)
return archive
@ -153,8 +165,9 @@ class FileTranscriber(VirtualTranscriber):
file: FileStorage
def __init__(self, logger: logging.Logger, sender: VirtualSender, translater: Optional[LibreTranslater],
vosk_uri: VoskUri, timings_path: Path, mail: Optional[MailAddress], ip: str, file: FileStorage) -> None:
super().__init__(logger, sender, translater, vosk_uri, timings_path, mail, ip)
punctuater: Optional[DMPPunctuater], vosk_uri: VoskUri, timings_path: Path,
mail: Optional[MailAddress], ip: str, file: FileStorage) -> None:
super().__init__(logger, sender, translater, punctuater, vosk_uri, timings_path, mail, ip)
self.file = file
# MyPy cannot infer this
assert self.file.filename is not None
@ -172,9 +185,9 @@ class YoutubeTranscriber(VirtualTranscriber):
url: str
def __init__(self, logger: logging.Logger, sender: VirtualSender, translater: Optional[LibreTranslater],
vosk_uri: VoskUri, timings_path: Path, mail: Optional[MailAddress], ip: str,
url: str, ffmpeg_location: str) -> None:
super().__init__(logger, sender, translater, vosk_uri, timings_path, mail, ip)
punctuater: Optional[DMPPunctuater], vosk_uri: VoskUri, timings_path: Path,
mail: Optional[MailAddress], ip: str, url: str, ffmpeg_location: str) -> None:
super().__init__(logger, sender, translater, punctuater, vosk_uri, timings_path, mail, ip)
self.url = url
self.sender.set_display_name(self.url)
self.ffmpeg_location = ffmpeg_location

4
scribe/types.py

@ -1,4 +1,4 @@
from typing import NewType, TypedDict, TypeVar, Callable, Any, Sequence
from typing import NewType, TypedDict, TypeVar, Callable, Any, List
from dataclasses import dataclass
from enum import Enum
import datetime
@ -36,7 +36,7 @@ class VoskPartial(TypedDict):
class VoskPhrase(TypedDict):
result: Sequence[VoskWord]
result: List[VoskWord]
text: str

Loading…
Cancel
Save