In [23]:
from PIL import Image
import requests

import torch
from torch import nn
from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from typing import Optional, Union, Tuple

In [43]:
class VisionLanguageConnector(nn.Module):
    def __init__(self, hidden_size, projection_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, projection_dim, bias=False)
        )

    def forward(self, x):
        return self.mlp(x)
        
class ClipWithProjection():
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, hidden_size, projection_dim):
        super().__init__()
        
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_language_connector = VisionLanguageConnector(hidden_size, projection_dim)

    def forward(
        self,
        image = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.vision_language_connector(pooled_output)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

In [44]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

In [46]:
# model = ClipWithProjection(768, 512)
# model.forward(image)

In [47]:
class AudioLanguageConnector:
    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        self.phi2_tokenizer.max_length = projection_dim

    def __call__(self, text):
        text = f"<audio_start> {text} <audio_end>"
        tokens = self.phi2_tokenizer(text, return_tensors="pt", return_attention_mask=False)
        return tokens
        

class WhisperWithProjection:
    def __init__(self, projection_dim):
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        self.model.config.forced_decoder_ids = None
        self.audio_language_connector = AudioLanguageConnector(projection_dim)
        
    def forward(self, audio):
        input_features = self.processor(audio["array"],
                                   sampling_rate=audio["sampling_rate"],
                                   return_tensors="pt").input_features
        # generate token ids
        predicted_ids = self.model.generate(input_features)
        # decode token ids to text        
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)

        audio_embeddings = self.audio_language_connector(transcription)
        return audio_embeddings

In [48]:
class TextModality:
    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        self.phi2_tokenizer.max_length = projection_dim


    def __call__(self, text):
        tokens = self.phi2_tokenizer(text, return_tensors="pt", return_attention_mask=False)
        return tokens

In [77]:
class MultiModalPhi2:
    def __init__(self):
        self.text_modality = TextModality(projection_dim=768)
        self.whisper_w_proj = WhisperWithProjection(projection_dim=512)
        self.clip_w_proj = ClipWithProjection(hidden_size=768, projection_dim=768)
        self.llm = self.load_llm()

    def load_llm(self):
        model_name = "microsoft/phi-2"
        
        bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16)
    
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            trust_remote_code=True,
            device_map="cuda:0"
        )
        model.config.use_cache = False
        return model

    def forward(self, audio, image, text):
        if text is not None:
            text_embed = self.text_modality(text)["input_ids"]
        if audio is not None:
            audio_embed = self.whisper_w_proj.forward(audio)["input_ids"]
        if image is not None:
            image_embed = self.clip_w_proj.forward(image)[0]
        print(text_embed.shape, text_embed.dtype)
        print(audio_embed.shape, audio_embed.dtype)
        print(image_embed.shape, image_embed.dtype)
        
        inputs = torch.concat([text_embed, audio_embed, image_embed], dim=1)
        print(inputs.shape, inputs.dtype)
        outputs = self.llm(inputs)

        return outputs 
        

    def generate(self, audio, text):
        text_embeddings = self.text_modality(text)
        audio_embeddings = self.whisper_w_proj.forward(audio)
        inputs = torch.concat([text_embed["input_ids"], audio_embed["input_ids"]], dim=1)
        
        outputs = self.llm.generate(inputs, max_length=200)
        text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]
        print(text)

In [74]:
from datasets import load_dataset
audio_ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = audio_ds[0]["audio"]

In [58]:
text = "explain about the audio"

In [59]:
# image

In [78]:
model = MultiModalPhi2()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [79]:
model.forward(audio, image, text)

torch.Size([1, 5]) torch.int64
torch.Size([1, 33]) torch.int64
torch.Size([1, 768]) torch.float32
torch.Size([1, 806]) torch.float32


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)