Encode and decode example with T5
In [1]:
pip install -q transformers[sentencepiece]~=4.33.0
In [2]:
# NOTE(review): `ipd` is not referenced by any cell visible in this notebook; it may be
# a leftover from cells that were removed — confirm before dropping it.
import IPython.display as ipd
import torch
# Last expression of the cell: displays the installed PyTorch version as the cell output.
torch.__version__
Out[2]:
In [3]:
from transformers import T5Tokenizer, T5EncoderModel, T5ForConditionalGeneration
# One checkpoint name shared by the model, the encoder and the tokenizer, so the three
# can never drift apart when the checkpoint is changed.
CHECKPOINT = "t5-base"

# Full encoder-decoder model; provides both `model.encoder` and `model.generate`.
model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)
# Stand-alone encoder loaded from the same checkpoint.
encoder = T5EncoderModel.from_pretrained(CHECKPOINT)
# `padding` / `truncation` are arguments of the tokenizer *call* (tokenizer(texts, ...)),
# not of `from_pretrained`, where they are silently ignored. They are therefore not
# passed here; request them when tokenizing instead.
tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)
In [6]:
body = "Sea turtles (superfamily Chelonioidea), sometimes called marine turtles,[3] are reptiles of the order Testudines and of the suborder Cryptodira. The seven existing species of sea turtles are the flatback, green, hawksbill, leatherback, loggerhead, Kemp's ridley, and olive ridley sea turtles.[4] All of the seven species listed above, except for the flatback, are present in US waters, and are listed as endangered and/or threatened under the Endangered Species Act.[5] The flatback itself exists in the waters of Australia, Papua New Guinea and Indonesia.[5] Sea turtles can be categorized as hard-shelled (cheloniid) or leathery-shelled (dermochelyid).[6] The only dermochelyid species of sea turtle is the leatherback.[6]"
inputs = [f"summarize: {body}"]

NOISE_STD = 1e-3  # standard deviation of the Gaussian perturbation added to the encoder output
SEED = 0          # fixed seed so that re-running the cell reproduces the same perturbation

# Encode strings with T5.
# truncation=True caps over-long inputs at the tokenizer's maximum length instead of
# feeding the model an arbitrarily long sequence.
encoding = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)

# Inference only: no_grad avoids building an autograd graph for the encoder pass and
# for the perturbation.
with torch.no_grad():
    embeddings = model.encoder(**encoding)

    # Perturb embeddings a little bit, using a dedicated seeded generator so the global
    # RNG state is left untouched.
    hidden = embeddings.last_hidden_state
    noise = torch.normal(
        mean=0.0,
        std=NOISE_STD,
        size=hidden.shape,
        generator=torch.Generator().manual_seed(SEED),
    )
    embeddings.last_hidden_state = hidden + noise.to(device=hidden.device, dtype=hidden.dtype)

# Decode same embeddings with T5 back to text.
# The encoder attention mask is forwarded explicitly: generate() does not build one when
# it is given encoder_outputs, so without it padded positions would be attended to as
# soon as `inputs` holds sequences of different lengths.
tokens = model.generate(
    encoder_outputs=embeddings,
    attention_mask=encoding["attention_mask"],
)
tokenizer.batch_decode(tokens, skip_special_tokens=True)
Out[6]:
Comments
Comments powered by Disqus