2025-01-22 14:47:50 +08:00
|
|
|
import math
|
|
|
|
|
import torch
|
2025-05-08 12:39:44 +08:00
|
|
|
import torch.nn.functional as F
|
2025-01-22 14:47:50 +08:00
|
|
|
import numpy as np
|
2025-05-23 15:21:41 +08:00
|
|
|
from torch import nn
|
2025-01-22 14:47:50 +08:00
|
|
|
|
|
|
|
|
from .transformers import VisionRotary, Decoder
|
2025-05-08 12:39:44 +08:00
|
|
|
from .spinner import spinner_run
|
2025-01-22 14:47:50 +08:00
|
|
|
|
|
|
|
|
class Vision(torch.nn.Module):
    """Common base class for the per-model vision wrappers in this file.

    The constructor copies shared handles from ``base``, then invokes the
    ``init_config()`` and ``load()`` hooks.  ``load``, ``forward``, ``export``
    and ``embed`` raise ``NotImplementedError`` here and are meant to be
    supplied by subclasses.
    """

    def __init__(self, visual, base):
        """Capture shared state and run the subclass hooks.

        Args:
            visual: Vision encoder module; it is switched to eval mode.
            base: Object exposing ``model_type``, ``embed``, ``tokenizer``,
                ``config``, ``hidden_size`` and ``llm_config``.
        """
        super().__init__()
        # Default quantisation / fusion settings (subclasses reassign quant_bit).
        self.quant_bit = 8
        self.quant_block = 128
        self.transformer_fuse = True
        # Handles borrowed from the surrounding model wrapper.
        self.model_type = base.model_type
        self.visual = visual.eval()
        self.embed_ = base.embed
        self.tokenizer = base.tokenizer
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        # mllama
        self.cross_attention_states = None
        self.cross_attention_mask = None
        # Hooks run last so they can read everything assigned above.
        self.init_config()
        self.load()

    @staticmethod
    def get_vision(model_type):
        """Return the wrapper class registered for ``model_type``, else ``None``."""
        registry = {
            'deepseek-vl': DeepSeekVL,
            'internvl_chat': InternVLVision,
            'qwen': QwenVision,
            'qwen2_vl': Qwen2Vision,
            'qwen2_5_vl': Qwen2_5Vision,
            'qwen2_5_omni': Qwen2_5OmniVision,
            'mllama': MllamaVision,
            'gemma3': Gemma3Vision,
            'idefics3': Idefics3Vision,
            'smolvlm': Idefics3Vision,
            'llava_qwen2': MobileCLIPVision,
        }
        return registry.get(model_type)

    def init_config(self):
        """Write the default (OpenAI CLIP) image normalisation into ``llm_config``.

        ``image_mean`` is the CLIP mean scaled by 255 and ``image_norm`` is the
        reciprocal of the CLIP std scaled by 255.
        """
        from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

        self.llm_config['is_visual'] = True
        mean_255 = np.array(OPENAI_CLIP_MEAN) * 255.0
        inv_std_255 = 1 / (np.array(OPENAI_CLIP_STD) * 255.0)
        self.llm_config['image_mean'] = mean_255.tolist()
        self.llm_config['image_norm'] = inv_std_255.tolist()

    def export(self, onnx_path):
        """Export hook; must be overridden by subclasses."""
        raise NotImplementedError

    def load(self):
        """Model-specific setup hook; must be overridden by subclasses."""
        raise NotImplementedError

    def str_to_ids(self, prompt):
        """Tokenize ``prompt`` and return the tokenizer's ``'input_ids'`` entry."""
        return self.tokenizer(prompt, return_tensors="pt")['input_ids']

    def forward(self, images):
        """Forward hook; must be overridden by subclasses."""
        raise NotImplementedError

    def embed(self, input_ids, images = None, videos = None):
        """Embedding hook; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
2025-05-23 15:21:41 +08:00
|
|
|
class DeepSeekVL(Vision):
    """Vision wrapper that runs ``visual`` followed by ``base.model.aligner``."""

    def __init__(self, visual, base):
        """Run the base setup, then keep the aligner and encoder handles."""
        super().__init__(visual, base)
        self.quant_bit = 8
        self.aligner = base.model.aligner
        self.vision_model = visual

    def load(self):
        """Fix the input resolution at 1024 and record it in ``llm_config``."""
        self.image_size = 1024
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        # vision_start / vision_end / image_pad are not written here; the
        # tokenizer-based assignments remain disabled:
        # self.llm_config['vision_start'] = self.tokenizer.img_start_id
        # self.llm_config['vision_end'] = self.tokenizer.img_end_id
        # self.llm_config['image_pad'] = self.tokenizer.img_pad_id

    def init_config(self):
        """Write normalisation constants (mean 0, std 1) into ``llm_config``.

        ``image_mean`` holds mean * 255 and ``image_norm`` holds 1 / std / 255
        for each of the three channels.
        """
        channel_mean = [0.0, 0.0, 0.0]
        channel_std = [1.0, 1.0, 1.0]
        self.llm_config['is_visual'] = True
        self.llm_config['image_mean'] = [m * 255.0 for m in channel_mean]
        self.llm_config['image_norm'] = [1.0 / s / 255.0 for s in channel_std]
        self.llm_config['image_size_unit'] = 14

    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        A random float32 image of shape (1, 3, image_size, image_size) is used
        as the example input; axes 0, 2 and 3 are declared dynamic.
        """
        dummy_images = torch.randn(
            (1, 3, self.image_size, self.image_size), dtype=torch.float32
        )
        onnx_model = f'{onnx_path}/visual.onnx'
        dynamic_axes = {
            "input_images": {0: "size", 2: "height", 3: "width"},
        }
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def forward(self, images):
        """Encode ``images``, apply the aligner and swap the first two axes."""
        features = self.vision_model(images)
        vit_embeds = self.aligner(features)
        # For mnn's embedding, the order is (seq, batch, hidden)
        return vit_embeds.permute(1, 0, 2)
|
|
|
|
|
|
|
|
|
|
|
2025-04-28 11:38:44 +08:00
|
|
|
class InternVLVision(Vision):
    """Vision wrapper: encoder -> pixel shuffle -> ``base.model.mlp1``."""

    def __init__(self, visual, base):
        """Run the base setup, then keep the encoder, projector and layer index."""
        super().__init__(visual, base)
        self.quant_bit = 8
        self.vision_model = visual
        self.mlp1 = base.model.mlp1
        self.select_layer = base.model.select_layer

    def load(self):
        """Read image size and downsample ratio from ``config``; record the size."""
        self.image_size = self.config.force_image_size
        self.downsample_ratio = self.config.downsample_ratio
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        # The tokenizer-based marker ids remain disabled:
        # self.llm_config['vision_start'] = self.tokenizer.img_start_id
        # self.llm_config['vision_end'] = self.tokenizer.img_end_id
        # self.llm_config['image_pad'] = self.tokenizer.img_pad_id

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth on an (N, W, H, C) tensor.

        NOTE(review): ``.int()`` is called on products of shape entries, so
        this presumably relies on those entries being tensors (e.g. while the
        module is traced for ONNX export) -- confirm before calling eagerly.
        """
        n, w, h, c = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        h_scaled = (h * scale_factor).int()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, h_scaled, (c / scale_factor).int())
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(
            n,
            h_scaled,
            (w * scale_factor).int(),
            (c / (scale_factor * scale_factor)).int(),
        )
        return x.permute(0, 2, 1, 3).contiguous()

    def extract_feature(self, pixel_values):
        """Return projected image embeddings with the first two axes swapped.

        The last hidden state is used when ``select_layer`` is -1, otherwise
        the hidden state at index ``select_layer``.  The first token is
        dropped, the rest is reshaped to a square grid, pixel-shuffled and
        passed through ``mlp1``.
        """
        use_last = self.select_layer == -1
        outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=not use_last,
            return_dict=True,
        )
        if use_last:
            vit_embeds = outputs.last_hidden_state
        else:
            vit_embeds = outputs.hidden_states[self.select_layer]
        # Drop the first token of every sequence.
        vit_embeds = vit_embeds[:, 1:, :]

        # Side length of the square token grid (same tensor-shape caveat as
        # in pixel_shuffle).
        side = (vit_embeds.shape[1] ** 0.5).int()
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], side, side, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)

        # For mnn's embedding, the order is (seq, batch, hidden)
        return vit_embeds.permute(1, 0, 2)

    def init_config(self):
        """Write ImageNet normalisation constants into ``llm_config``.

        ``image_mean`` holds mean * 255 and ``image_norm`` holds 1 / std / 255.
        """
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.llm_config['is_visual'] = True
        self.llm_config['image_mean'] = [m * 255.0 for m in imagenet_mean]
        self.llm_config['image_norm'] = [1.0 / s / 255.0 for s in imagenet_std]
        self.llm_config['image_size_unit'] = 14

    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        A random float32 image of shape (1, 3, image_size, image_size) is used
        as the example input; axes 0, 2 and 3 are declared dynamic.
        """
        dummy_images = torch.randn(
            (1, 3, self.image_size, self.image_size), dtype=torch.float32
        )
        onnx_model = f'{onnx_path}/visual.onnx'
        dynamic_axes = {
            "input_images": {0: "size", 2: "height", 3: "width"},
        }
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def forward(self, images):
        """Delegate to ``extract_feature``."""
        return self.extract_feature(images)
|
|
|
|
|
|
2025-01-22 14:47:50 +08:00
|
|
|
class QwenVision(Vision):
    """Vision wrapper whose prompts carry the image location as token bytes."""

    def __init__(self, visual, base):
        """Run the base setup and select 16-bit quantisation."""
        super().__init__(visual, base)
        self.quant_bit = 16

    def load(self):
        """Read image settings from ``config.visual`` and marker ids from the tokenizer."""
        visual_cfg = self.config.visual
        self.image_start_id = visual_cfg['image_start_id']
        self.image_size = visual_cfg['image_size']
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        self.llm_config['vision_start'] = self.tokenizer.img_start_id
        self.llm_config['vision_end'] = self.tokenizer.img_end_id
        self.llm_config['image_pad'] = self.tokenizer.img_pad_id

    @spinner_run('export visual to ')
    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        The example input is a random (1, 3, image_size, image_size) tensor;
        only axis 0 is declared dynamic.
        """
        dummy_images = torch.randn((1, 3, self.image_size, self.image_size))
        onnx_model = f'{onnx_path}/visual.onnx'
        dynamic_axes = {
            "input_images": {0: "size"},
        }
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def forward(self, images):
        """Run the encoder and swap the first two axes of its output."""
        return self.visual(images).transpose(1, 0)

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids``, overwriting image spans with encoder features.

        Without any ``image_start_id`` token this is a plain text embedding.
        Otherwise each span between an ``image_start_id`` token and the token
        ``image_start_id + 1`` is decoded as UTF-8 bytes (cut at the first
        ``image_start_id + 2``) and the resulting strings are handed to
        ``visual.encode``.  The ``images`` and ``videos`` arguments are not
        read.
        """
        start_id = self.image_start_id
        if not torch.any(input_ids == start_id):
            return self.embed_(input_ids)

        bos_pos = torch.where(input_ids == start_id)
        eos_pos = torch.where(input_ids == start_id + 1)
        # One row per image: (row index, start column, end column).
        spans = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)

        terminator = start_id + 2
        image_refs = []
        for row, begin, end in spans:
            payload = input_ids[row][begin + 1 : end - 1].tolist()
            payload = payload[: payload.index(terminator)]
            image_refs.append(bytes(payload).decode('utf-8'))

        image_features = self.visual.encode(image_refs).transpose(1, 0)
        hidden_states = self.embed_(input_ids)
        for idx, (row, begin, end) in enumerate(spans):
            hidden_states[begin + 1 : end, row] = image_features[:, idx]
        return hidden_states
|
|
|
|
|
|
|
|
|
|
class Qwen2Vision(Vision):
    """Vision wrapper for Qwen2-VL style encoders (patch embed, blocks, merger).

    Images referenced in a prompt via ``<img>...</img>`` are encoded while the
    prompt is tokenized; the resulting embeddings are cached on the instance
    and spliced into the text embeddings by ``embed``.
    """

    def __init__(self, visual, base):
        """Set patching defaults, then run the base setup.

        The plain attributes are assigned before ``super().__init__`` because
        ``load()`` reads ``image_height`` and is invoked from the base
        constructor.
        """
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        # Per-image caches filled by img_process / vision_reshape.
        self.image_embeds = []
        self.image_grid_thw = []
        super().__init__(visual, base)
        self.quant_bit = 4

    def load(self):
        """Read marker ids from ``config`` and wrap the encoder's sub-modules."""
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_id
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_start|>'
        self.vision_end_token = '<|vision_end|>'
        self.image_pad_token = '<|image_pad|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Attribute-name mapping handed to Decoder via ``self``.
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'qkv_proj': 'qkv',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, expanding every ``<img>...</img>`` reference.

        Each reference is loaded (over HTTP(S) or from a local path), encoded
        via ``img_process`` and replaced by vision-start, N image-pad and
        vision-end tokens, where N is the number of embeddings produced.  An
        optional ``<hw>H,W</hw>`` tag inside the reference updates
        ``image_height`` / ``image_width`` before the image is processed.
        """
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            pieces = []
            for part in re.split(pattern, prompt):
                if not re.match(pattern, part):
                    pieces.append(part)
                    continue
                img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                # find <hw></hw> in image_content
                match = re.search(r'<hw>(.*?)</hw>', img_content)
                if match:
                    img_content = img_content[:match.start()] + img_content[match.end():]
                    hw = match.group(1).split(',')
                    self.image_height, self.image_width = int(hw[0]), int(hw[1])
                if img_content.startswith('http://') or img_content.startswith('https://'):
                    image_obj = Image.open(requests.get(img_content, stream=True).raw)
                else:
                    image_obj = Image.open(img_content)
                img_pad_len = self.img_process(image_obj)
                img_pad_str = self.image_pad_token * img_pad_len
                pieces.append(f'{self.vision_start_token}{img_pad_str}{self.vision_end_token}')
            txt_prompt = ''.join(pieces)
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        return input_ids

    def get_position_ids(self, input_ids, seq_len, token_len):
        """Build the 3-row (temporal, height, width) position ids.

        When ``token_len`` is truthy a single column holding ``seq_len - 1``
        is returned.  Otherwise text tokens get identical, consecutive ids in
        all three rows, while the pad tokens of each image get grid
        coordinates offset by the running position.
        """
        if token_len:
            return torch.tensor([[seq_len - 1]] * 3, dtype=torch.int)
        txt_len, vision_idx, cur_idx = 0, 0, 0
        position_ids_list = []
        for token in input_ids.flatten():
            if token != self.image_pad_id:
                txt_len += 1
            if token == self.vision_start_id:
                # Flush the text run that ends with this vision-start token.
                text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
                cur_idx += txt_len
                txt_len = 0
                position_ids_list.append(torch.stack([text_index, text_index, text_index]))
            elif token == self.vision_end_id:
                t, h, w = self.image_grid_thw[vision_idx]
                h = h // self.merge_size
                w = w // self.merge_size
                t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
                h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
                w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
                position_ids_list.append(torch.stack([t_index, h_index, w_index]) + cur_idx)
                # Text resumes one past the largest position used by the
                # image, i.e. after max(t, h, w) steps; advancing by w alone
                # would reuse positions whenever h (or t) exceeds w.
                cur_idx += max(t, h, w)
                vision_idx += 1
        if txt_len > 0:
            text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
            position_ids_list.append(torch.stack([text_index, text_index, text_index]))
        return torch.cat(position_ids_list, dim=1)

    def vision_position_ids(self, grid_thw):
        """Return (row, column) patch coordinates in merge-window order.

        For every ``(t, h, w)`` grid the h x w coordinates are regrouped so
        that the ``merge_size`` x ``merge_size`` patches of one merged token
        are adjacent.  ``t`` is not used.
        """
        pos_ids = []
        merge = self.merge_size
        for t, h, w in grid_thw:
            llm_h, llm_w = h // merge, w // merge
            # compute pos_ids
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(llm_h, merge, llm_w, merge)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(llm_h, merge, llm_w, merge)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids]))
        return torch.cat(pos_ids, dim=0)

    def vision_attention_mask(self, grid_thw, cu_window_seqlens = None):
        """Build a block-diagonal additive attention mask.

        Entries are float32-min everywhere except inside the segments
        delimited by consecutive cumulative lengths, which are set to 0.  The
        lengths come from ``cu_window_seqlens`` when given, otherwise from
        one segment of h * w patches per temporal step of each grid.
        """
        seq_len = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]
        if cu_window_seqlens is None:
            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0)
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        else:
            cu_seqlens = cu_window_seqlens
        attention_mask = torch.full([1, seq_len, seq_len], torch.finfo(torch.float32).min)
        for i in range(1, len(cu_seqlens)):
            begin, end = cu_seqlens[i - 1], cu_seqlens[i]
            attention_mask[..., begin:end, begin:end] = 0
        return attention_mask

    def vision_reshape(self, images):
        """Cut ``images`` into flattened patches and record the grid shape.

        The batch is repeated ``temporal_patch_size`` times along axis 0 and
        rearranged into rows of channel * temporal_patch_size * patch_size^2
        values.  Returns ``(flatten_patches, grid_thw)`` and appends
        ``[grid_t, grid_h, grid_w]`` to ``self.image_grid_thw``.
        """
        patches = torch.concat([images] * self.temporal_patch_size, dim=0)
        _, channel, height, width = patches.shape
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )
        grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
        self.image_grid_thw.append([grid_t, grid_h, grid_w])
        return flatten_patches, grid_thw

    def images_forward(self, images):
        """Patchify ``images``, derive position ids and mask, then run ``forward``."""
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        attention_mask = self.vision_attention_mask(grid_thw)
        return self.forward(flatten_patches, position_ids, attention_mask)

    def forward(self, flatten_patches, position_ids, attention_mask):
        """Run patch embedding, all blocks and the merger.

        Returns the merged embeddings with a singleton axis inserted at
        position 1.
        """
        rotary_pos_emb = self.rotary(position_ids)
        hidden_states = self.patch_embed(flatten_patches)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        for blk in self.blocks:
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask)
        image_embeds = self.merger(hidden_states)
        return image_embeds.unsqueeze(1)

    def smart_resize(self, height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280):
        """Round a size to multiples of ``factor`` within the pixel budget.

        Raises:
            ValueError: if either side is below ``factor`` or the aspect
                ratio exceeds 200.
        """
        if height < factor or width < factor:
            raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
        elif max(height, width) / min(height, width) > 200:
            raise ValueError(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            # Too large: shrink both sides by a common ratio, rounding down.
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            # Too small: grow both sides by a common ratio, rounding up.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def img_process(self, image):
        """Preprocess and encode one image; return its number of embeddings.

        The image is converted to RGB, resized to the ``smart_resize`` of the
        configured ``image_height`` x ``image_width`` (not of the image's own
        size), rescaled by 1/255 and normalised with the OpenAI CLIP
        statistics.  The embedding is appended to ``self.image_embeds``.
        """
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        resized_height, resized_width = self.smart_resize(self.image_height, self.image_width)
        # Named data_format so the builtin ``format`` is not shadowed.
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        image = np.expand_dims(image, [0])
        # presumably channels-last at this point, moved to channels-first -- TODO confirm
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image_embed = self.images_forward(image)
        self.image_embeds.append(image_embed)
        return image_embed.shape[0]

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` and overwrite image-pad slots with cached embeddings.

        ``images`` and ``videos`` are not read; the embeddings come from
        ``self.image_embeds``.
        """
        input_embeds = self.embed_(input_ids)
        if self.image_embeds:
            image_mask = (input_ids == self.image_pad_id).squeeze()
            input_embeds[image_mask] = torch.concat(self.image_embeds, dim=0).to(input_embeds.dtype)
        return input_embeds

    @spinner_run('export visual to ')
    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        Example inputs use 900 patches of width 1176 (= 3 * 2 * 14 * 14); the
        patch-count axes of all three inputs are declared dynamic.
        """
        patch = torch.randn([900, 1176])
        position_ids = torch.zeros([2, 900], dtype=torch.int32)
        attention_mask = torch.zeros([1, 900, 900], dtype=torch.float)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (patch, position_ids, attention_mask),
                          onnx_model,
                          input_names=['patches', 'position_ids', 'attention_mask'],
                          output_names=['image_embeds'],
                          dynamic_axes={
                              "patches": { 0: "size" },
                              "position_ids": { 1: "size" },
                              "attention_mask": { 1: "size", 2: "size" }
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model
|
|
|
|
|
|
2025-05-23 15:21:41 +08:00
|
|
|
class Gemma3Vision(Vision):
    """Vision wrapper: ``base.vision_tower`` followed by the multi-modal projector."""

    def __init__(self, visual, base):
        """Capture the image size, run the base setup, then keep model handles.

        ``image_size`` is assigned before ``super().__init__`` because
        ``load()`` reads it and is invoked from the base constructor.
        """
        # read from gemma3_map
        self.image_size = base.image_size
        # embedding functions
        super().__init__(visual, base)
        self.quant_bit = 8
        self.vision_tower = base.vision_tower
        self.multi_modal_projector = base.multi_modal_projector.float()

    def init_config(self):
        """Write normalisation constants and image marker ids into ``llm_config``.

        Mean and std are both 0.5 per channel; ``image_mean`` stores
        mean * 255 and ``image_norm`` stores 1 / std / 255.  The same list
        objects are kept on the instance under the (historically misspelled)
        ``*_from_preprcessor_config`` attributes.
        """
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        self.image_mean_from_preprcessor_config = [m * 255.0 for m in mean]
        self.image_std_from_preprcessor_config = [1.0 / s / 255.0 for s in std]
        self.llm_config['is_visual'] = True
        self.llm_config['image_mean'] = self.image_mean_from_preprcessor_config
        self.llm_config['image_norm'] = self.image_std_from_preprcessor_config
        self.llm_config['vision_start'] = self.config.boi_token_index
        self.llm_config['vision_end'] = self.config.eoi_token_index
        self.llm_config['image_pad'] = self.config.image_token_index

    def load(self):
        """Record the image size in ``llm_config``."""
        self.llm_config['image_size'] = self.image_size

    def forward(self, pixel_values):
        """Encode ``pixel_values``, project them and swap the first two axes."""
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features.permute(1, 0, 2)

    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        The example input is a random (1, 3, image_size, image_size) tensor;
        axes 0, 2 and 3 are declared dynamic.
        """
        dummy_images = torch.randn((1, 3, self.image_size, self.image_size))
        onnx_model = f'{onnx_path}/visual.onnx'
        dynamic_axes = {
            "input_images": {0: "size", 2: "height", 3: "width"},
        }
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def embed(self, input_ids, images = None, videos = None):
        """Return the text embeddings of ``input_ids``.

        ``images`` and ``videos`` are accepted (and ignored) so the signature
        matches ``Vision.embed``; callers that pass them no longer fail with
        a ``TypeError``.
        """
        return self.embed_(input_ids)
|
|
|
|
|
|
2025-05-08 12:39:44 +08:00
|
|
|
class Qwen2_5Vision(Qwen2Vision):
    """Qwen2Vision variant that alternates windowed and full attention blocks."""

    def __init__(self, visual, base):
        """Run the parent setup, then read the window settings from ``visual``."""
        super().__init__(visual, base)
        self.merge_unit = self.merge_size * self.merge_size
        self.window_size = visual.window_size
        self.fullatt_block_indexes = visual.fullatt_block_indexes

    def get_window_index(self, grid_thw):
        """Compute the window ordering of merged tokens.

        Returns ``(window_index, cu_window_seqlens)``: a tensor holding the
        merged-token indices regrouped window by window, and a Python list
        of cumulative patch counts (starting at 0) marking window boundaries.
        """
        per_image_index = []
        cu_window_seqlens = [0]
        index_offset = 0
        # Window edge length measured in merged tokens.
        win = self.window_size // self.merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.merge_size
            llm_grid_w = grid_w // self.merge_size
            token_count = grid_t * llm_grid_h * llm_grid_w
            index = torch.arange(token_count).reshape(grid_t, llm_grid_h, llm_grid_w)
            pad_h = win - llm_grid_h % win
            pad_w = win - llm_grid_w % win
            num_windows_h = (llm_grid_h + pad_h) // win
            num_windows_w = (llm_grid_w + pad_w) // win
            # Pad with -100 so the grid tiles evenly; padding is removed below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t, num_windows_h, win, num_windows_w, win
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t, num_windows_h * num_windows_w, win, win
            )
            is_real = index_padded != -100
            seqlens = is_real.sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            per_image_index.append(index_new + index_offset)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            index_offset += token_count.item()
        window_index = torch.cat(per_image_index, dim=0)
        return window_index, cu_window_seqlens

    def images_forward(self, images):
        """Patchify ``images`` and run ``forward`` with both attention masks.

        The stacked mask holds the per-image mask at index 0 and the
        per-window mask at index 1.
        """
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        per_image_mask = self.vision_attention_mask(grid_thw)
        per_window_mask = self.vision_attention_mask(grid_thw, cu_window_seqlens)
        attention_mask = torch.stack([per_image_mask, per_window_mask], dim=0)
        return self.forward(flatten_patches, position_ids, attention_mask, window_index)

    def forward(self, flatten_patches, position_ids, attention_mask, window_index):
        """Run the encoder on window-ordered patches.

        Position ids and hidden states are permuted into window order in
        groups of ``merge_unit`` patches.  Blocks listed in
        ``fullatt_block_indexes`` use ``attention_mask[0]``, all others
        ``attention_mask[1]``.  The merged output is restored to the original
        order and returned with a singleton axis at position 1.
        """
        hidden_states = self.patch_embed(flatten_patches)
        seq_len, _ = hidden_states.size()
        group_count = seq_len // self.merge_unit

        position_ids = position_ids.reshape(2, group_count, self.merge_unit)
        position_ids = position_ids[:, window_index, :].reshape(2, seq_len)
        rotary_pos_emb = self.rotary(position_ids)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)

        hidden_states = hidden_states.reshape(group_count, self.merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :].reshape(seq_len, -1)

        for layer_num, blk in enumerate(self.blocks):
            mask_slot = 0 if layer_num in self.fullatt_block_indexes else 1
            hidden_states, _ = blk(
                hidden_states,
                rotary_pos_emb=rotary_pos_emb,
                attention_mask=attention_mask[mask_slot],
            )

        image_embeds = self.merger(hidden_states)
        # Undo the window permutation applied above.
        reverse_indices = torch.argsort(window_index)
        image_embeds = image_embeds[reverse_indices, :]
        return image_embeds.unsqueeze(1)

    @spinner_run('export visual to ')
    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` and return that path.

        Example inputs use 400 patches of width 1176 and 100 window indices
        (400 divided by a merge unit of 4); the count axes of all four inputs
        are declared dynamic.
        """
        patch = torch.randn([400, 1176])
        position_ids = torch.zeros([2, 400], dtype=torch.int32)
        attention_mask = torch.zeros([2, 1, 400, 400], dtype=torch.float)
        window_index = torch.arange(100, dtype=torch.int32)
        onnx_model = f'{onnx_path}/visual.onnx'
        dynamic_axes = {
            "patches": {0: "size"},
            "position_ids": {1: "size"},
            "attention_mask": {2: "size", 3: "size"},
            "window_index": {0: "size"},
        }
        torch.onnx.export(
            self,
            (patch, position_ids, attention_mask, window_index),
            onnx_model,
            input_names=['patches', 'position_ids', 'attention_mask', 'window_index'],
            output_names=['image_embeds'],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model
|
|
|
|
|
|
|
|
|
|
class Qwen2_5OmniVision(Qwen2_5Vision):
    """Qwen2_5Vision variant that reads its settings from ``config.thinker_config``."""

    def __init__(self, visual, base):
        """Set patching defaults, run the parent setup, select 8-bit quantisation."""
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        self.image_embeds = None
        super().__init__(visual, base)
        self.quant_bit = 8

    def load(self):
        """Read marker ids from the thinker config and wrap the encoder's sub-modules."""
        # From here on ``self.config`` refers to the nested thinker config.
        self.config = self.config.thinker_config
        thinker_cfg = self.config
        self.vision_start_id = thinker_cfg.vision_start_token_id
        self.vision_end_id = thinker_cfg.vision_end_token_id
        self.image_pad_id = thinker_cfg.image_token_index
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_bos|>'
        self.vision_end_token = '<|vision_eos|>'
        self.image_pad_token = '<|IMAGE|>'
        # load model
        visual_cfg = self.visual.config
        if hasattr(visual_cfg, "embed_dim"):
            self.hidden_size = visual_cfg.embed_dim
        else:
            self.hidden_size = visual_cfg.hidden_size
        self.num_attention_heads = visual_cfg.num_heads
        self.num_key_value_heads = visual_cfg.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Attribute-name mapping handed to Decoder via ``self``; here the
        # attention projections are mapped separately (q / k / v).
        decoder_names = {
            'self_attn': 'attn',
            'mlp': 'mlp',
            'input_layernorm': 'norm1',
            'post_attention_layernorm': 'norm2',
        }
        attention_names = {
            'q_proj': 'q',
            'k_proj': 'k',
            'v_proj': 'v',
            'o_proj': 'proj',
        }
        self.model_map = {
            'decoder': decoder_names,
            'attention': attention_names,
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for layer_id, block in enumerate(self.visual.blocks.children()):
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger
|
|
|
|
|
|
2025-01-22 14:47:50 +08:00
|
|
|
class MllamaVision(Vision):
    """Vision wrapper for mllama models.

    Images found in a prompt are run through the vision tower and the result
    is stored in ``self.cross_attention_states``; ``embed`` returns only the
    text embeddings.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.multi_modal_projector = base.multi_modal_projector
        # Every image opened by str_to_ids so far, in order of appearance.
        self.image_objs = []

    def load(self):
        """Record the configured vision image size in the llm config."""
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.config.vision_config.image_size
        self.image_size = self.config.vision_config.image_size

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, replacing each ``<img>...</img>`` span with ``<|image|>``.

        The text inside an ``<img>`` span is treated as an http(s) URL or a
        local path and opened with PIL. Each image opened by this call is
        appended to ``self.image_objs`` and passed to ``img_process``.

        Returns:
            The ``input_ids`` tensor produced by the tokenizer.
        """
        new_images = []
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            pieces = []
            for part in re.split(pattern, prompt):
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    if img_content.startswith(('http://', 'https://')):
                        image = Image.open(requests.get(img_content, stream=True).raw)
                    else:
                        image = Image.open(img_content)
                    new_images.append(image)
                    pieces.append('<|image|>')
                else:
                    pieces.append(part)
            txt_prompt = ''.join(pieces)
        else:
            txt_prompt = prompt
        self.image_objs.extend(new_images)
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        # image process
        # Only images introduced by this call are encoded. Re-encoding the
        # whole history on every call repeated a full vision forward per old
        # image without changing the result, since img_process overwrites
        # cross_attention_states and only the last image processed survives.
        for img in new_images:
            self.img_process(img)
        return input_ids

    def img_process(self, image):
        """Preprocess one PIL image and store its vision features.

        The image is converted to RGB, resized to a 560x560 square (bicubic),
        rescaled by 1/255, normalized with the OpenAI CLIP mean/std and moved
        to channel-first layout. Three leading singleton axes are added and
        three all-zero copies are concatenated on axis 2, so the array passed
        to ``forward`` has shape (1, 1, 4, C, H, W). The result is written to
        ``self.cross_attention_states``.
        """
        # NOTE(review): this overrides the image_size that load() read from
        # the vision config -- confirm that 560 is intended for all checkpoints.
        self.image_size = 560
        resized_height = self.image_size
        resized_width = self.image_size
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        # Named data_format so the builtin `format` is not shadowed.
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        # Assumes the array is channel-last (H, W, C) here -- TODO confirm.
        image = image.transpose(2, 0, 1)
        image = np.expand_dims(image, [0, 1, 2])
        # Real image in slot 0 of axis 2, zero padding in the other three slots.
        pad_val = np.zeros_like(image)
        image = np.concatenate([image, pad_val, pad_val, pad_val], axis=2)
        image = torch.from_numpy(image)
        self.cross_attention_states = self.forward(image)

    def forward(self, images):
        """Run the vision tower and projector on ``images``.

        ``self.visual`` is called with a fixed ``aspect_ratio_ids`` of
        ``[[1]]`` and ``aspect_ratio_mask`` of ``[[[1, 0, 0, 0]]]``. Its first
        output is cast to the projector weight dtype, projected, and reshaped
        to ``(-1, N, self.hidden_size)``, where ``N`` is the second-to-last
        dimension of the unprojected vision output.
        """
        aspect_ratio_ids = torch.tensor([[1]])
        aspect_ratio_mask = torch.tensor([[[1, 0, 0, 0]]])
        vision_outputs = self.visual(images, aspect_ratio_ids, aspect_ratio_mask)
        cross_attention_states = vision_outputs[0]
        cross_attention_states = cross_attention_states.type(self.multi_modal_projector.weight.dtype)
        cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
            -1, cross_attention_states.shape[-2], self.hidden_size)
        return cross_attention_states

    def embed(self, input_ids, images = None, videos = None):
        """Return the text embeddings for ``input_ids``; ``images``/``videos`` are unused."""
        txt_embeds = self.embed_(input_ids)
        return txt_embeds
|
|
|
|
|
|
|
|
|
|
# SmolVLM & SmolVLM2
|
|
|
|
|
class Idefics3Vision(Vision):
    """Vision wrapper registered for the 'idefics3' and 'smolvlm' model types."""

    def __init__(self, visual, base):
        # Read the sizes before the base constructor runs: Vision.__init__
        # calls init_config() and load(), which use these attributes.
        tile_edge = visual.config.max_image_size['longest_edge']
        self.patch_size = tile_edge
        self.image_max_size = visual.config.size['longest_edge']
        self.image_height = tile_edge
        self.image_width = tile_edge
        self.image_embeds = []
        self.image_mean = np.full(3, 0.5, dtype=np.float32)
        self.image_norm = np.full(3, 0.5, dtype=np.float32)
        super().__init__(visual, base)
        self.connector = base.model.model.connector.float()
        self.visual = self.visual.float()
        self.quant_bit = 8
        self.transformer_fuse = False

    def load(self):
        """Resolve the special-token ids, publish them, and grab the submodules."""
        self.vision_start_token = '<fake_token_around_image>'
        self.vision_end_token = '<fake_token_around_image>'
        self.image_pad_token = '<image>'
        self.global_image_token = '<global-img>'

        encode = self.tokenizer.encode
        self.vision_start_id = encode(self.vision_start_token)[0]
        # Start and end markers are the same token, so they share one id.
        self.vision_end_id = self.vision_start_id
        self.image_pad_id = encode(self.image_pad_token)[0]
        self.global_image_id = encode(self.global_image_token)[0]

        for key, value in (
            ('image_size_unit', self.patch_size),
            ('image_size', self.image_height),
            ('image_max_size', self.image_max_size),
            ('vision_start', self.vision_start_id),
            ('vision_end', self.vision_end_id),
            ('image_pad', self.image_pad_id),
            ('global_image', self.global_image_id),
        ):
            self.llm_config[key] = value

        # load model
        vision_embeddings = self.visual.embeddings
        self.patch_embedding = vision_embeddings.patch_embedding
        self.position_embedding = vision_embeddings.position_embedding
        self.encoder = self.visual.encoder
        self.post_layernorm = self.visual.post_layernorm

    def init_config(self):
        """Write the normalization constants, scaled for 0-255 pixels, to the llm config."""
        self.llm_config['is_visual'] = True
        mean_255 = self.image_mean * 255.0
        inv_std_255 = 1 / (self.image_norm * 255.0)
        self.llm_config['image_mean'] = mean_255.tolist()
        self.llm_config['image_norm'] = inv_std_255.tolist()

    def str_to_ids(self, prompt):
        """Expand ``<img>...</img>`` spans into image placeholder tokens and tokenize.

        An optional ``<hw>H,W</hw>`` tag inside the span sets
        ``self.image_height`` / ``self.image_width`` and is removed from the
        image location. Each image is run through ``img_process``, which
        appends its embeddings to ``self.image_embeds``.

        Returns:
            The ``input_ids`` tensor produced by the tokenizer.
        """
        if not ('<img>' in prompt and '</img>' in prompt):
            txt_prompt = prompt
        else:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            pieces = []
            for part in re.split(pattern, prompt):
                if not re.match(pattern, part):
                    pieces.append(part)
                    continue
                img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                # find <hw></hw> in image_content
                hw_match = re.search(r'<hw>(.*?)</hw>', img_content)
                if hw_match:
                    img_content = img_content[:hw_match.start()] + img_content[hw_match.end():]
                    hw = hw_match.group(1).split(',')
                    self.image_height, self.image_width = int(hw[0]), int(hw[1])
                if img_content.startswith(('http://', 'https://')):
                    image_obj = Image.open(requests.get(img_content, stream=True).raw)
                else:
                    image_obj = Image.open(img_content)
                img_pad_len, grid_h, grid_w = self.img_process(image_obj)
                img_pad_str = self.image_pad_token * img_pad_len
                if grid_h > 0 and grid_w > 0:
                    # One placeholder run per tile, tagged with its 1-based row/col.
                    for row in range(1, grid_h + 1):
                        for col in range(1, grid_w + 1):
                            pieces.append(f"{self.vision_start_token}<row_{row}_col_{col}>{img_pad_str}")
                        pieces.append("\n")
                    pieces.append("\n")
                # The global image placeholder run is always emitted.
                pieces.append(f'{self.vision_start_token}{self.global_image_token}{img_pad_str}{self.vision_end_token}')
            txt_prompt = ''.join(pieces)
        return self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']

    def vision_reshape(self, images):
        """Cut a (batch, channel, height, width) tensor into square tiles.

        Returns:
            ``(tiles, grid_h, grid_w)`` where ``tiles`` has shape
            ``(batch * grid_h * grid_w, channel, patch_size, patch_size)``.
        """
        edge = self.patch_size
        batch, channel, height, width = images.shape
        grid_h = height // edge
        grid_w = width // edge
        tiles = images.reshape(batch, channel, grid_h, edge, grid_w, edge)
        # Bring the two grid axes forward so each tile is contiguous per index.
        tiles = tiles.permute(0, 2, 4, 1, 3, 5)
        tiles = tiles.reshape(batch * grid_h * grid_w, channel, edge, edge)
        return tiles, grid_h, grid_w

    def images_forward(self, images):
        """Thin alias for ``forward``."""
        return self.forward(images)

    def forward(self, pixel_values):
        """Embed ``pixel_values``, run encoder, layer norm and connector; add an axis at dim 2."""
        hidden = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        hidden = hidden + self.position_embedding.weight
        hidden = self.encoder(hidden)[0]
        hidden = self.post_layernorm(hidden)
        return self.connector(hidden).unsqueeze(2)

    def get_size(self, height: int, width: int):
        """Round the longer side up to a multiple of ``patch_size`` and scale the other.

        The shorter side is derived from the original aspect ratio, then also
        rounded up to a multiple of ``patch_size``. Both sides are finally
        capped at ``image_max_size``.
        """
        unit = self.patch_size

        def round_up(value):
            return math.ceil(value / unit) * unit

        aspect_ratio = width / height
        if width >= height:
            width = round_up(width)
            height = round_up(int(width / aspect_ratio))
        else:
            height = round_up(height)
            width = round_up(int(height * aspect_ratio))
        height = min(height, self.image_max_size)
        width = min(width, self.image_max_size)
        return height, width

    def img_process(self, image):
        """Preprocess one image, encode it, and queue the embeddings.

        Returns:
            ``(img_pad_len, grid_h, grid_w)``: ``img_pad_len`` is dimension 1
            of the encoder output; ``grid_h`` / ``grid_w`` are the tile grid
            size, or ``0, 0`` when only the global image is used.
        """
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = to_numpy_array(convert_to_rgb(image))
        resized_height, resized_width = self.get_size(self.image_height, self.image_width)
        # NOTE(review): the computed size is immediately overridden, so the
        # tiling branch below can never run -- confirm this is intentional.
        resized_height, resized_width = 1, 1
        channel_fmt = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC

        def to_tensor(array):
            # Rescale to 0-1, normalize, add a batch axis, move axis 3 to position 1.
            array = rescale(array, scale=1 / 255.0, input_data_format=channel_fmt)
            array = normalize(image=array, mean=self.image_mean, std=self.image_norm, input_data_format=channel_fmt)
            array = np.expand_dims(array, [0]).transpose(0, 3, 1, 2)
            return torch.from_numpy(array)

        global_image = to_tensor(
            resize(image, size=(self.patch_size, self.patch_size), resample=resample, input_data_format=channel_fmt)
        )
        needs_tiles = resized_height > self.patch_size or resized_width > self.patch_size
        if needs_tiles:
            image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=channel_fmt)
            image, grid_h, grid_w = self.vision_reshape(to_tensor(image))
            # Tiles first, global image last.
            image = torch.concat([image, global_image], dim=0)
        else:
            grid_h, grid_w = 0, 0
            image = global_image
        image_embed = self.images_forward(image)
        num_images, img_pad_len, _, vision_hidden_size = image_embed.shape
        self.image_embeds.append(image_embed.reshape(-1, 1, vision_hidden_size))
        return img_pad_len, grid_h, grid_w

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` and overwrite image-pad positions with the queued image embeddings."""
        input_embeds = self.embed_(input_ids)
        if self.image_embeds is not None and len(self.image_embeds) > 0:
            image_mask = (input_ids == self.image_pad_id).squeeze()
            stacked = torch.concat(self.image_embeds, dim=0)
            input_embeds[image_mask] = stacked.to(input_embeds.dtype)
        return input_embeds

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export this module to ``{onnx_path}/visual.onnx`` with a dynamic batch axis."""
        pixel_values = torch.randn([1, 3, self.patch_size, self.patch_size])
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(
            self,
            pixel_values,
            onnx_model,
            input_names=['pixel_values'],
            output_names=['image_embeds'],
            dynamic_axes={
                "pixel_values": { 0: "size" },
            },
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model
|
|
|
|
|
|
|
|
|
|
# FastVLM
|
|
|
|
|
class MobileCLIPVision(QwenVision):
    """Vision wrapper registered for the 'llava_qwen2' model type (FastVLM)."""

    def __init__(self, visual, base):
        super().__init__(visual, base)
        # Both the vision tower and the projector are cast to float32.
        self.visual = visual.float()
        self.mm_projector = base.model.model.mm_projector.float()
        self.quant_bit = 8

    def init_config(self):
        """Publish a zero mean and a 1/255 scale as the image normalization constants."""
        self.llm_config['is_visual'] = True
        self.llm_config['image_mean'] = np.zeros(3).tolist()
        self.llm_config['image_norm'] = (np.ones(3) / 255.0).tolist()

    def load(self):
        """Read the image size from the visual config and register -200 as the image token id."""
        placeholder_id = -200
        self.image_size = self.visual.config['image_cfg']['image_size']
        self.image_start_id = placeholder_id
        self.llm_config['image_size'] = self.image_size
        # The same id is used for the start, end and pad slots.
        for key in ('vision_start', 'vision_end', 'image_pad'):
            self.llm_config[key] = placeholder_id

    def forward(self, images):
        """Encode ``images``, project the features, and swap the first two axes."""
        projected = self.mm_projector(self.visual(images))
        return projected.permute(1, 0, 2)
|