# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from datasets import Dataset

from transformers import is_faiss_available
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers.models.rag.configuration_rag import RagConfig
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer as BartTokenizer
from transformers.testing_utils import require_faiss, require_sentencepiece, require_tokenizers, require_torch


if is_faiss_available():
    import faiss


@require_faiss
class RagRetrieverTest(TestCase):
    """Tests for `RagRetriever` backed by a tiny two-document FAISS-indexed dataset.

    Every test queries the retriever with the same two vectors (all ones and all
    minus ones). The dummy dataset holds document "0" with an all-ones embedding
    and document "1" with an all-twos embedding, indexed with the inner-product
    metric, so the expected best match is document "1" for the first query and
    document "0" for the second.
    """

    def setUp(self):
        """Create a scratch directory and the toy DPR / BART vocabularies."""
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tokenizer vocab
        self.dpr_vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        self.dpr_vocab = {token: i for i, token in enumerate(self.dpr_vocab_tokens)}

        # BART tokenizer vocab and merges
        # NOTE(review): the trailing "" vocab entry (and the matching unk_token=""
        # in `get_bart_tokenizer`) looks like it may once have been an
        # angle-bracket token such as "<unk>" - confirm against upstream.
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "",
        ]
        self.bart_vocab = dict(zip(vocab, range(len(vocab))))

        merges_raw = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        # Skip blank lines and "#" header lines; each remaining line is one merge pair.
        stripped = (line.strip() for line in merges_raw)
        self.bart_merges = [tuple(line.split()) for line in stripped if line and not line.startswith("#")]

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """Return a DPR question-encoder tokenizer built from the toy DPR vocab."""
        return DPRQuestionEncoderTokenizer(vocab=self.dpr_vocab)

    def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
        """Return a DPR context-encoder tokenizer built from the toy DPR vocab."""
        return DPRContextEncoderTokenizer(vocab=self.dpr_vocab)

    def get_bart_tokenizer(self) -> BartTokenizer:
        """Return the generator tokenizer built from the toy BART vocab and merges."""
        return BartTokenizer(vocab=self.bart_vocab, merges=self.bart_merges, unk_token="")

    def tearDown(self):
        """Remove the scratch directory created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def get_dummy_dataset(self):
        """Return a two-row dataset with a flat inner-product FAISS index on "embeddings"."""
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        return dataset

    def get_dummy_canonical_hf_index_retriever(self):
        """Return a retriever using the default index, with `load_dataset` patched to the dummy dataset."""
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
        )
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(config, **self._tokenizer_kwargs())
        return retriever

    def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
        """Return a retriever configured with `index_name="custom"`.

        Args:
            from_disk: When true, the dummy dataset and its FAISS index are written
                under `self.tmpdirname` and the retriever is built only from the
                paths stored in the config. Otherwise the in-memory dataset is
                wrapped in a `CustomHFIndex` and passed in directly.
        """
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="custom",
        )
        if from_disk:
            passages_path = os.path.join(self.tmpdirname, "dataset")
            index_path = os.path.join(self.tmpdirname, "index.faiss")
            config.passages_path = passages_path
            config.index_path = index_path
            # Save the index first, then drop it before saving the dataset itself.
            dataset.get_index("embeddings").save(index_path)
            dataset.drop_index("embeddings")
            dataset.save_to_disk(passages_path)
            # Release the in-memory dataset; the retriever below receives only the config.
            del dataset
            retriever = RagRetriever(config, **self._tokenizer_kwargs())
        else:
            retriever = RagRetriever(
                config,
                **self._tokenizer_kwargs(),
                index=CustomHFIndex(config.retrieval_vector_size, dataset),
            )
        return retriever

    def _tokenizer_kwargs(self):
        """Return the tokenizer keyword arguments shared by every `RagRetriever` built here."""
        return {
            "question_encoder_tokenizer": self.get_dpr_tokenizer(),
            "generator_tokenizer": self.get_bart_tokenizer(),
        }

    def _make_hidden_states(self):
        """Return the two float32 query vectors (all ones, all minus ones) used by every test."""
        return np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )

    def _check_retrieve(self, retriever):
        """Call `retriever.retrieve` on the standard queries and verify shapes, keys and ranking."""
        n_docs = 1
        hidden_states = self._make_hidden_states()
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

    def _check_save_and_reload(self, retriever, save_dir):
        """Save `retriever` to `save_dir`, reload it, and verify the reloaded copy can retrieve."""
        retriever.save_pretrained(save_dir)
        retriever = RagRetriever.from_pretrained(save_dir)
        self.assertIsInstance(retriever, RagRetriever)
        out = retriever.retrieve(self._make_hidden_states(), n_docs=1)
        self.assertIsNotNone(out)

    def test_canonical_hf_index_retriever_retrieve(self):
        self._check_retrieve(self.get_dummy_canonical_hf_index_retriever())

    def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
        retriever = self.get_dummy_canonical_hf_index_retriever()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # Keep `load_dataset` patched for the whole save / reload / retrieve sequence.
            with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
                mock_load_dataset.return_value = self.get_dummy_dataset()
                self._check_save_and_reload(retriever, tmp_dirname)

    def test_custom_hf_index_retriever_retrieve(self):
        self._check_retrieve(self.get_dummy_custom_hf_index_retriever(from_disk=False))

    def test_custom_hf_index_retriever_save_and_from_pretrained(self):
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            self._check_save_and_reload(retriever, tmp_dirname)

    def test_custom_hf_index_retriever_retrieve_from_disk(self):
        self._check_retrieve(self.get_dummy_custom_hf_index_retriever(from_disk=True))

    def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            self._check_save_and_reload(retriever, tmp_dirname)

    @require_torch
    @require_tokenizers
    @require_sentencepiece
    def test_hf_index_retriever_call(self):
        import torch

        n_docs = 1
        retriever = self.get_dummy_canonical_hf_index_retriever()
        question_input_ids = [[5, 7], [10, 11]]
        hidden_states = self._make_hidden_states()

        # Without `return_tensors`: token ids and masks are lists, embeddings a numpy array.
        out = retriever(question_input_ids, hidden_states, prefix=None, n_docs=n_docs)
        context_input_ids, context_attention_mask, retrieved_doc_embeds = (
            out["context_input_ids"],
            out["context_attention_mask"],
            out["retrieved_doc_embeds"],
        )
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertIsInstance(context_input_ids, list)
        self.assertIsInstance(context_attention_mask, list)
        self.assertIsInstance(retrieved_doc_embeds, np.ndarray)

        # With `return_tensors="pt"`: the same three outputs come back as torch tensors.
        out = retriever(
            question_input_ids,
            hidden_states,
            prefix=None,
            n_docs=n_docs,
            return_tensors="pt",
        )
        self.assertIn("doc_ids", out)
        context_input_ids, context_attention_mask, retrieved_doc_embeds = (
            out["context_input_ids"],
            out["context_attention_mask"],
            out["retrieved_doc_embeds"],
        )
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertIsInstance(context_input_ids, torch.Tensor)
        self.assertIsInstance(context_attention_mask, torch.Tensor)
        self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)

    @require_torch
    @require_tokenizers
    @require_sentencepiece
    def test_custom_hf_index_end2end_retriever_call(self):
        context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
        n_docs = 1
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
        retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)

        question_input_ids = [[5, 7], [10, 11]]
        hidden_states = self._make_hidden_states()
        out = retriever(question_input_ids, hidden_states, prefix=None, n_docs=n_docs)

        # check whether the retriever output consist of 6 attributes including tokenized docs
        self.assertEqual(len(out), 6)
        # check for doc token related keys in dictionary.
        for key in ("tokenized_doc_ids", "tokenized_doc_attention_mask"):
            self.assertIn(key, out)