support whisper long-form generation (#469)

* fix long form asr accuracy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test input pad issue

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
This commit is contained in:
Sihan Chen
2024-08-14 11:23:07 +08:00
committed by GitHub
parent 7aee7e4689
commit daec6803aa

View File

@@ -16,7 +16,7 @@ from transformers import WhisperForConditionalGeneration, WhisperProcessor
class WhisperModel:
"""Convert audio to text."""
def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu"):
def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
if device == "hpu":
# Explicitly link HPU with Torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -31,12 +31,11 @@ class WhisperModel:
self.model.eval()
self.language = language
self.hpu_max_len = hpu_max_len
if device == "hpu":
# do hpu graph warmup with a long enough input audio
# whisper has a receptive field of 30 seconds
# here we select a relatively long audio (~15 sec) to quickly warmup
self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav")
self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_30s_audio.wav")
def _audiosegment_to_librosawav(self, audiosegment):
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
@@ -59,11 +58,54 @@ class WhisperModel:
print("[ASR] warmup...")
waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
waveform = self._audiosegment_to_librosawav(waveform)
# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
_ = self.model.generate(inputs, language="chinese")
try:
processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)
except RuntimeError as e:
if "Padding size should be less than" in str(e):
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
raise e
if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)
_ = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)
def audio2text(self, audio_path):
"""Convert audio to text.
@@ -80,11 +122,52 @@ class WhisperModel:
audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
waveform = audio_dataset[0]["audio"]["array"]
# pylint: disable=E1101
inputs = self.processor.feature_extractor(
waveform, return_tensors="pt", sampling_rate=16_000
).input_features.to(self.device)
predicted_ids = self.model.generate(inputs, language=self.language)
try:
processed_inputs = self.processor(
waveform,
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16000,
)
except RuntimeError as e:
if "Padding size should be less than" in str(e):
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
else:
raise e
if processed_inputs.input_features.shape[-1] < 3000:
# short-form
processed_inputs = self.processor(
waveform,
return_tensors="pt",
sampling_rate=16000,
)
elif self.device == "hpu":
processed_inputs["input_features"] = torch.nn.functional.pad(
processed_inputs.input_features,
(0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
value=-1.5,
)
processed_inputs["attention_mask"] = torch.nn.functional.pad(
processed_inputs.attention_mask,
(0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
value=0,
)
predicted_ids = self.model.generate(
**(
processed_inputs.to(
self.device,
)
),
language=self.language,
)
# pylint: disable=E1101
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
if self.language in ["chinese", "mandarin"]:
@@ -96,20 +179,23 @@ class WhisperModel:
if __name__ == "__main__":
asr = WhisperModel(language="english")
asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")
# Test multilanguage asr
asr.language = "chinese"
urllib.request.urlretrieve(
"https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
"sample.wav",
)
asr.language = "chinese"
text = asr.audio2text("sample.wav")
asr.language = "english"
urllib.request.urlretrieve(
"https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
"sample.wav",
)
text = asr.audio2text("sample.wav")
os.remove("sample.wav")
for i in [5, 10, 30, 60]:
urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
text = asr.audio2text("sample.wav")