Most probably, the code snippet you're looking at came from https://github.com/facebookresearch/fairseq/blob/main/examples/wmt19/README.md
The example code there looks like this:
import torch
# English to German translation
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
                       checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
                       tokenizer='moses', bpe='fastbpe')
But most probably you'll run into some environment setup issues, because fairseq isn't easily usable "off-the-shelf". So you'll have to do something like this:
! pip install -U fastBPE sacremoses
! pip install -U hydra-core omegaconf bitarray
! git clone https://github.com/pytorch/fairseq && cd fairseq && pip install --editable ./
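After the editable install, a quick sanity check that the package is importable doesn't hurt (a minimal sketch; note that in a notebook you may need to restart the kernel/runtime before the editable install is picked up):

# Minimal sanity check that fairseq is importable after installation.
import fairseq
print(fairseq.__version__)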
After setting up the environment, you can try this again:
import torch
# English to German translation
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
                       checkpoint_file='model1.pt',
                       tokenizer='moses', bpe='fastbpe')
type(en2de)
[out]:
fairseq.hub_utils.GeneratorHubInterface
If we do some code digging, it points to https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#L97
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """
And if we look at the translate() function, it goes to https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#LL133C1-L145C76:
def translate(
    self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
) -> List[str]:
    return self.sample(sentences, beam, verbose, **kwargs)

def sample(
    self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
) -> List[str]:
    if isinstance(sentences, str):
        return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
    tokenized_sentences = [self.encode(sentence) for sentence in sentences]
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
    return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]
So .translate() returns a list of strings: one decoded top hypothesis per input sentence.
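That makes .translate() the simplest way to use the model end-to-end; a minimal sketch (the expected output matches the decoded result shown at the end of this answer):

# .translate() encodes the input, runs beam search and returns
# the decoded best hypothesis for each sentence.
en2de.translate(["Machine learning is great!"])
# ['Maschinelles Lernen ist großartig!']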
And if we dig deeper into the rabbit hole, we see the .generate() function from https://github.com/facebookresearch/fairseq/blob/main/fairseq/hub_utils.py#L170, which has this signature:
def generate(
    self,
    tokenized_sentences: List[torch.LongTensor],
    beam: int = 5,
    verbose: bool = False,
    skip_invalid_size_inputs=False,
    inference_step_args=None,
    prefix_allowed_tokens_fn=None,
    **kwargs
) -> List[List[Dict[str, torch.Tensor]]]:
And if you use the model with .generate() directly:
tokenized_sentences = en2de.encode("Machine learning is great!")
en2de.generate([tokenized_sentences])
[out]:
[[{'tokens': tensor([21259, 99, 4125, 15336, 34, 5013, 19663, 111, 2]),
   'score': tensor(-0.2017),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0026, 0.0049, 0.0067],
                        [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0021, 0.0029, 0.0034],
                        [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0031, 0.0072, 0.0059],
                        [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0093, 0.0214, 0.0088],
                        [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0274, 0.0126, 0.0065],
                        [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.0146, 0.1026, 0.0346],
                        [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.9409, 0.8484, 0.9341]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -0.5898, -0.0818, -0.1108, -0.1069])},
  {'tokens': tensor([21259, 99, 4125, 15336, 34, 19503, 111, 2]),
   'score': tensor(-0.3501),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0048, 0.0073],
                        [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0027, 0.0037],
                        [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0070, 0.0063],
                        [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0217, 0.0097],
                        [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0129, 0.0078],
                        [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.1076, 0.0373],
                        [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.8431, 0.9280]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -1.6566, -0.1113, -0.1072])},
  {'tokens': tensor([5725, 372, 8984, 3845, 34, 5013, 19663, 111, 2]),
   'score': tensor(-0.4066),
   'attention': tensor([[0.2876, 0.0278, 0.0040, 0.0030, 0.0192, 0.0150, 0.0032, 0.0075, 0.0083],
                        [0.1374, 0.0755, 0.0019, 0.0379, 0.0087, 0.0062, 0.0027, 0.0044, 0.0043],
                        [0.0817, 0.0269, 0.4516, 0.0801, 0.0227, 0.0120, 0.0038, 0.0084, 0.0065],
                        [0.0684, 0.0034, 0.0067, 0.0091, 0.1939, 0.1039, 0.0097, 0.0224, 0.0099],
                        [0.0862, 0.0031, 0.0040, 0.0030, 0.1022, 0.2868, 0.0296, 0.0135, 0.0073],
                        [0.0415, 0.0058, 0.0054, 0.0066, 0.0373, 0.0400, 0.0146, 0.1016, 0.0351],
                        [0.2972, 0.8574, 0.5264, 0.8603, 0.6160, 0.5361, 0.9364, 0.8422, 0.9287]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-2.0029, -0.3431, -0.1785, -0.0286, -0.1586, -0.6527, -0.0782, -0.1101, -0.1071])},
  {'tokens': tensor([21259, 99, 4125, 15336, 34, 8404, 111, 2]),
   'score': tensor(-0.5465),
   'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0047, 0.0074],
                        [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0026, 0.0037],
                        [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0071, 0.0064],
                        [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0221, 0.0095],
                        [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0125, 0.0077],
                        [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.1046, 0.0372],
                        [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.8464, 0.9282]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -3.2290, -0.1100, -0.1075])},
  {'tokens': tensor([9467, 5293, 34, 5013, 19663, 111, 2]),
   'score': tensor(-0.5483),
   'attention': tensor([[0.2876, 0.0109, 0.0157, 0.0154, 0.0032, 0.0069, 0.0093],
                        [0.1374, 0.0110, 0.0081, 0.0065, 0.0027, 0.0039, 0.0047],
                        [0.0817, 0.2288, 0.0219, 0.0131, 0.0034, 0.0076, 0.0069],
                        [0.0684, 0.0045, 0.1818, 0.0989, 0.0097, 0.0224, 0.0100],
                        [0.0862, 0.0037, 0.0979, 0.2854, 0.0276, 0.0135, 0.0076],
                        [0.0415, 0.0074, 0.0343, 0.0404, 0.0146, 0.1005, 0.0363],
                        [0.2972, 0.7337, 0.6403, 0.5402, 0.9388, 0.8452, 0.9253]]),
   'alignment': tensor([]),
   'positional_scores': tensor([-2.1557, -0.5372, -0.1502, -0.6968, -0.0819, -0.1092, -0.1072])}]]
.generate() returns a list of lists of dicts, where the keys are names and the values are tensors.
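To make the nesting concrete, here's a hedged sketch with two input sentences (the second sentence is my own example):

# results[sentence_index][beam_rank] -> dict with 'tokens', 'score', etc.
sentences = ["Machine learning is great!", "Hello world!"]
results = en2de.generate([en2de.encode(s) for s in sentences])
print(len(results))     # 2, i.e. one entry per input sentence
print(len(results[0]))  # 5, i.e. one hypothesis per beam (default beam=5)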
The outermost list is indexed by sentence. If you have one sentence, the result for that sentence is:
tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences)
translation_sent1 = results[0]
len(translation_sent1)
[out]:
5
You'll see that each sentence has 5 translation results. This is because the beam size is set to 5 by default. Each dictionary in the inner list corresponds to the translation from one beam.
tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences, beam=2)
translation_sent1 = results[0]
len(translation_sent1)
[out]:
2
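The hypotheses come back best-first (that's why sample() above takes hypos[0] for each sentence), so you can walk the n-best list and decode each one; a small sketch:

# Decode each beam hypothesis together with its overall score.
for rank, hypo in enumerate(translation_sent1):
    print(rank, hypo['score'].item(), en2de.decode(hypo['tokens']))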
And to get the best translation:
tokenized_sentences = [en2de.encode("Machine learning is great!")]
results = en2de.generate(tokenized_sentences, beam=2)
translation_sent1 = results[0]           # 2 translations from the 2 beams for the 1st sentence.
best_translation = translation_sent1[0]  # The best of the 2 beam hypotheses.
best_translation
[out]:
{'tokens': tensor([21259, 99, 4125, 15336, 34, 5013, 19663, 111, 2]),
 'score': tensor(-0.2017),
 'attention': tensor([[0.2876, 0.0079, 0.0066, 0.0211, 0.0117, 0.0107, 0.0026, 0.0049, 0.0067],
                      [0.1374, 0.0239, 0.0076, 0.0090, 0.0062, 0.0049, 0.0021, 0.0029, 0.0034],
                      [0.0817, 0.0073, 0.0472, 0.3804, 0.0206, 0.0112, 0.0031, 0.0072, 0.0059],
                      [0.0684, 0.0017, 0.0033, 0.0079, 0.1894, 0.1042, 0.0093, 0.0214, 0.0088],
                      [0.0862, 0.0021, 0.0021, 0.0055, 0.0991, 0.2868, 0.0274, 0.0126, 0.0065],
                      [0.0415, 0.0053, 0.0049, 0.0089, 0.0388, 0.0405, 0.0146, 0.1026, 0.0346],
                      [0.2972, 0.9517, 0.9284, 0.5673, 0.6342, 0.5417, 0.9409, 0.8484, 0.9341]]),
 'alignment': tensor([]),
 'positional_scores': tensor([-0.5091, -0.0979, -0.0993, -0.0672, -0.1520, -0.5898, -0.0818, -0.1108, -0.1069])}
And to get the string representation, we fetch the tokens and decode them:
en2de.decode(best_translation['tokens'])
[out]:
Maschinelles Lernen ist großartig!
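Putting it together, this mirrors what .translate() does internally (encode -> generate -> decode), and is handy when you also want the scores or the full n-best list:

# End-to-end sketch: encode each sentence, run beam search,
# then decode the top (first) hypothesis for each sentence.
sentences = ["Machine learning is great!"]
tokenized = [en2de.encode(s) for s in sentences]
batched_hypos = en2de.generate(tokenized, beam=5)
print([en2de.decode(hypos[0]['tokens']) for hypos in batched_hypos])
# ['Maschinelles Lernen ist großartig!']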
Here's the working code for the above examples: https://www.kaggle.com/alvations/how-to-use-fairseq-wmt19-models