support whisper long-form generation (#469)
* fix long form asr accuracy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix test input pad issue

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
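What the diff does, in short: audio is no longer truncated to Whisper's 30-second receptive field before decoding. `__init__` gains an `hpu_max_len` parameter (default 8192), and both the HPU warmup and `audio2text` now extract full-length features with `padding="longest"`, falling back to the default short-form path for clips under 30 seconds. A minimal usage sketch, assuming the class lives in a `whisper_model` module (the import path and local file name are illustrative, not part of the diff):

    import urllib.request

    from whisper_model import WhisperModel  # import path assumed

    # hpu_max_len only takes effect when device="hpu"
    asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")
    urllib.request.urlretrieve("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav", "sample.wav")
    print(asr.audio2text("sample.wav"))  # transcribes the full 60 s, not just the first 30 s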
@@ -16,7 +16,7 @@ from transformers import WhisperForConditionalGeneration, WhisperProcessor
 class WhisperModel:
     """Convert audio to text."""
 
-    def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu"):
+    def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
         if device == "hpu":
             # Explicitly link HPU with Torch
             from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -31,12 +31,11 @@ class WhisperModel:
         self.model.eval()
 
         self.language = language
+        self.hpu_max_len = hpu_max_len
 
         if device == "hpu":
             # do hpu graph warmup with a long enough input audio
-            # here we select a relatively long audio (~15 sec) to quickly warmup
-            self._warmup_whisper_hpu_graph("https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav")
+            # whisper has a receptive field of 30 seconds
+            self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav")
+            self._warmup_whisper_hpu_graph("https://github.com/Spycsh/assets/raw/main/ljspeech_30s_audio.wav")
 
     def _audiosegment_to_librosawav(self, audiosegment):
         # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
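Why two warmup clips: HPU graph mode compiles a separate graph per input shape, so warming up with both a 30-second and a 60-second clip pre-builds the two encoder input shapes that serving traffic can produce. A quick sanity check of those shapes (a sketch; the 100 frames-per-second rate and 80 mel bins are Whisper feature-extractor constants, and 8192 is the new `hpu_max_len` default from this diff):

    # Whisper's feature extractor emits 100 log-mel frames per second of audio.
    FRAMES_PER_SECOND = 100
    MEL_BINS = 80

    short_form_shape = (1, MEL_BINS, 30 * FRAMES_PER_SECOND)  # 30 s clip -> (1, 80, 3000)
    long_form_shape = (1, MEL_BINS, 8192)                     # 60 s clip -> padded to hpu_max_len
    assert short_form_shape[-1] == 3000  # exactly one 30 s receptive field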
@@ -59,11 +58,54 @@ class WhisperModel:
         print("[ASR] warmup...")
         waveform = AudioSegment.from_file("warmup.wav").set_frame_rate(16000)
         waveform = self._audiosegment_to_librosawav(waveform)
-        # pylint: disable=E1101
-        inputs = self.processor.feature_extractor(
-            waveform, return_tensors="pt", sampling_rate=16_000
-        ).input_features.to(self.device)
-        _ = self.model.generate(inputs, language="chinese")
+
+        try:
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                truncation=False,
+                padding="longest",
+                return_attention_mask=True,
+                sampling_rate=16000,
+            )
+        except RuntimeError as e:
+            if "Padding size should be less than" in str(e):
+                # short-form
+                processed_inputs = self.processor(
+                    waveform,
+                    return_tensors="pt",
+                    sampling_rate=16000,
+                )
+            else:
+                raise e
+
+        if processed_inputs.input_features.shape[-1] < 3000:
+            # short-form
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                sampling_rate=16000,
+            )
+        else:
+            processed_inputs["input_features"] = torch.nn.functional.pad(
+                processed_inputs.input_features,
+                (0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
+                value=-1.5,
+            )
+            processed_inputs["attention_mask"] = torch.nn.functional.pad(
+                processed_inputs.attention_mask,
+                (0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
+                value=0,
+            )
+
+        _ = self.model.generate(
+            **(
+                processed_inputs.to(
+                    self.device,
+                )
+            ),
+            language=self.language,
+        )
 
     def audio2text(self, audio_path):
         """Convert audio to text.
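The warmup above and `audio2text` below share the same prepare-then-dispatch pattern. Pulled out on its own it reads roughly like this (an illustrative sketch, not a helper the PR adds; the `-1.5` pad value and the `hpu_max_len + 1` mask length are carried over from the diff as-is):

    import torch

    def prepare_whisper_inputs(processor, waveform, hpu_max_len=8192):
        """Sketch of the short-form/long-form input preparation used in this PR."""
        try:
            # padding="longest" keeps every frame instead of truncating at 30 s
            inputs = processor(
                waveform,
                return_tensors="pt",
                truncation=False,
                padding="longest",
                return_attention_mask=True,
                sampling_rate=16000,
            )
        except RuntimeError as e:
            if "Padding size should be less than" not in str(e):
                raise
            # clip shorter than one 30 s window: use the default short-form path
            return processor(waveform, return_tensors="pt", sampling_rate=16000)

        if inputs.input_features.shape[-1] < 3000:
            # fewer than 3000 mel frames (~30 s): still short-form
            return processor(waveform, return_tensors="pt", sampling_rate=16000)

        # long-form: pad to one fixed length so HPU graphs see a static shape
        inputs["input_features"] = torch.nn.functional.pad(
            inputs.input_features,
            (0, hpu_max_len - inputs.input_features.size(-1)),
            value=-1.5,
        )
        inputs["attention_mask"] = torch.nn.functional.pad(
            inputs.attention_mask,
            (0, hpu_max_len + 1 - inputs.attention_mask.size(-1)),
            value=0,
        )
        return inputs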
@@ -80,11 +122,52 @@ class WhisperModel:
         audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
         waveform = audio_dataset[0]["audio"]["array"]
 
-        # pylint: disable=E1101
-        inputs = self.processor.feature_extractor(
-            waveform, return_tensors="pt", sampling_rate=16_000
-        ).input_features.to(self.device)
-        predicted_ids = self.model.generate(inputs, language=self.language)
+        try:
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                truncation=False,
+                padding="longest",
+                return_attention_mask=True,
+                sampling_rate=16000,
+            )
+        except RuntimeError as e:
+            if "Padding size should be less than" in str(e):
+                # short-form
+                processed_inputs = self.processor(
+                    waveform,
+                    return_tensors="pt",
+                    sampling_rate=16000,
+                )
+            else:
+                raise e
+        if processed_inputs.input_features.shape[-1] < 3000:
+            # short-form
+            processed_inputs = self.processor(
+                waveform,
+                return_tensors="pt",
+                sampling_rate=16000,
+            )
+        elif self.device == "hpu":
+            processed_inputs["input_features"] = torch.nn.functional.pad(
+                processed_inputs.input_features,
+                (0, self.hpu_max_len - processed_inputs.input_features.size(-1)),
+                value=-1.5,
+            )
+            processed_inputs["attention_mask"] = torch.nn.functional.pad(
+                processed_inputs.attention_mask,
+                (0, self.hpu_max_len + 1 - processed_inputs.attention_mask.size(-1)),
+                value=0,
+            )
+
+        predicted_ids = self.model.generate(
+            **(
+                processed_inputs.to(
+                    self.device,
+                )
+            ),
+            language=self.language,
+        )
         # pylint: disable=E1101
         result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
         if self.language in ["chinese", "mandarin"]:
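One asymmetry worth noting: the warmup's long-form branch always pads, while `audio2text` pads only under `elif self.device == "hpu"`. On CPU the unpadded features and attention mask go straight to `generate`, which in recent transformers releases handles inputs longer than 3000 frames with Whisper's built-in chunked long-form decoding. The dispatch, condensed (a sketch of the control flow above, not code from the PR):

    def long_form_strategy(num_frames: int, device: str) -> str:
        """Summarize which input-preparation path audio2text takes."""
        if num_frames < 3000:  # under ~30 s: one receptive field
            return "short-form: default 3000-frame features"
        if device == "hpu":  # static shape for HPU graph reuse
            return "long-form: pad features to hpu_max_len, mask to hpu_max_len + 1"
        return "long-form: unpadded features, transformers' chunked decoding"

    print(long_form_strategy(2999, "cpu"))  # short-form
    print(long_form_strategy(6000, "hpu"))  # padded long-form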
@@ -96,20 +179,23 @@ class WhisperModel:
 
 
 if __name__ == "__main__":
-    asr = WhisperModel(language="english")
+    asr = WhisperModel(model_name_or_path="openai/whisper-small", language="english", device="cpu")
 
     # Test multilanguage asr
-    asr.language = "chinese"
     urllib.request.urlretrieve(
         "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
         "sample.wav",
     )
+    asr.language = "chinese"
     text = asr.audio2text("sample.wav")
 
     asr.language = "english"
     urllib.request.urlretrieve(
         "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
         "sample.wav",
     )
     text = asr.audio2text("sample.wav")
 
     os.remove("sample.wav")
+    for i in [5, 10, 30, 60]:
+        urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
+        text = asr.audio2text("sample.wav")