Merge pull request NVlabs#3 from XueFuzhao/main_vila_branch
Add SigLIP and Rope Scaling
XueFuzhao authored Jan 20, 2024
2 parents 79fe8a7 + b564c38 commit b1469bd
Showing 24 changed files with 3,464 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -22,6 +22,7 @@ dist

# Data
!**/alpaca-data-conversation.json
eval/*

# Editor
.idea
15 changes: 12 additions & 3 deletions llava/eval/utils.py
@@ -2,6 +2,10 @@

import torch
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel
from transformers.models.siglip import (
    SiglipVisionModel,
    SiglipImageProcessor,
)

from llava.conversation import SeparatorStyle, conv_templates
from llava.model import LlavaLlamaForCausalLM
@@ -43,9 +47,14 @@ def build_model(model_name, conv_version):
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    image_processor = CLIPImageProcessor.from_pretrained(
        model.config.mm_vision_tower, torch_dtype=torch.float16
    )
    if "siglip" in model_name:
        image_processor = SiglipImageProcessor.from_pretrained(
            model.config.mm_vision_tower, torch_dtype=torch.float16
        )
    else:
        image_processor = CLIPImageProcessor.from_pretrained(
            model.config.mm_vision_tower, torch_dtype=torch.float16
        )

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    # assert mm_use_im_start_end
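
For context, a minimal standalone sketch of the processor selection above; the checkpoint name and test image are illustrative and not taken from this diff:

from PIL import Image
from transformers import CLIPImageProcessor
from transformers.models.siglip import SiglipImageProcessor

model_name = "google/siglip-so400m-patch14-384"  # hypothetical checkpoint name
if "siglip" in model_name:
    processor = SiglipImageProcessor.from_pretrained(model_name)
else:
    processor = CLIPImageProcessor.from_pretrained(model_name)

image = Image.new("RGB", (640, 480))  # stand-in for a real sample
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # e.g. torch.Size([1, 3, 384, 384]) for the 384px SigLIP variant
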
25 changes: 24 additions & 1 deletion llava/model/llava.py
@@ -29,6 +29,12 @@
    LlamaForCausalLM,
    LlamaModel,
)

from transformers.models.siglip import (
    SiglipVisionModel,
    SiglipImageProcessor,
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@@ -113,6 +119,11 @@ def __init__(self, config: LlamaConfig):
"patch_size": 14,
}
)
elif self.vision_tower_class == "siglip":
self.vision_tower = [
SiglipVisionModel.from_pretrained(config.mm_vision_tower)
]
vision_config = self.vision_tower[0].config
else:
self.vision_tower = [
CLIPVisionModel.from_pretrained(config.mm_vision_tower)
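
For orientation, a sketch (not part of the diff) of the configuration fields exposed by a SigLIP tower loaded this way; the checkpoint name and the quoted values are for the so400m variant and are illustrative:

from transformers import SiglipVisionModel

tower = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")  # hypothetical choice
vision_config = tower.config
print(vision_config.hidden_size)  # typically 1152 for so400m
print(vision_config.image_size)   # typically 384
print(vision_config.patch_size)   # typically 14
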
@@ -142,6 +153,8 @@ def vision_tower_class(self):
            vision_tower_arch = "eva"
        elif "raw" in self.config.mm_vision_tower.lower():
            vision_tower_arch = "raw"
        elif "siglip" in self.config.mm_vision_tower.lower():
            vision_tower_arch = "siglip"
        else:
            vision_tower_arch = "clip"
        return vision_tower_arch
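
A reduced sketch of the substring dispatch above, restricted to the siglip/clip branches and applied to two illustrative checkpoint names:

def tower_arch(mm_vision_tower: str) -> str:
    # mirrors vision_tower_class: the architecture is inferred from the tower name
    name = mm_vision_tower.lower()
    return "siglip" if "siglip" in name else "clip"

print(tower_arch("google/siglip-so400m-patch14-384"))   # siglip
print(tower_arch("openai/clip-vit-large-patch14-336"))  # clip
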
@@ -167,7 +180,14 @@ def initialize_vision_modules(
            add_visual_expert_attn=self.config.add_visual_expert_attn,
        )

        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        print(" * Loading vision tower from", vision_tower)
        print(" * Using vision tower class", self.vision_tower_class)
        print(" hasattr vision_tower: ", hasattr(self, "vision_tower"))
        if hasattr(self, "vision_tower") and self.vision_tower_class == "siglip":
            image_processor = SiglipImageProcessor.from_pretrained(vision_tower)
        else:
            image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        print(" * Using image processor", image_processor)

        if not hasattr(self, "vision_tower"):
            if self.vision_tower_class == "qwen":
@@ -227,6 +247,9 @@ def initialize_vision_modules(
"patch_size": 14,
}
)
elif self.vision_tower_class == "siglip":
vision_tower = SiglipVisionModel.from_pretrained(vision_tower)
vision_config = vision_tower.config

else:
vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
1 change: 1 addition & 0 deletions llava/train/arguments.py
@@ -8,6 +8,7 @@
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    trust_remote_code: bool = field(default=True)
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    tune_vision_encoder: bool = field(default=False)
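
A minimal sketch of how the new trust_remote_code field surfaces through the usual HfArgumentParser flow; the reduced dataclass and the command-line values are illustrative:

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    trust_remote_code: bool = field(default=True)


parser = transformers.HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--trust_remote_code", "False"])
print(model_args.trust_remote_code)  # False (defaults to True when the flag is omitted)
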
12 changes: 8 additions & 4 deletions llava/train/dataset.py
@@ -311,6 +311,7 @@ def _process_image(image_file, multimodal_cfg: dict, resize=False):

    image_folder = multimodal_cfg["image_folder"]
    processor = multimodal_cfg["image_processor"]
    image_size = multimodal_cfg["image_size"]
    if isinstance(image_file, str):
        if image_folder is not None:
            image_file = os.path.join(image_folder, image_file)
@@ -328,7 +329,7 @@ def _process_image(image_file, multimodal_cfg: dict, resize=False):
        image = image.resize((30, 30))

    if resize:
        image = image.resize((336, 336))
        image = image.resize((image_size, image_size))
    if multimodal_cfg["image_aspect_ratio"] == "keep":
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
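
A toy sketch (not part of the diff) of what the new image_size key changes: the resize target follows the vision tower's native resolution instead of the hard-coded 336 used by CLIP-L/14-336. The value 384 below stands in for a SigLIP tower:

from PIL import Image


def resize_for_tower(image, multimodal_cfg):
    # resize to the tower's native resolution rather than a fixed 336
    size = multimodal_cfg["image_size"]
    return image.resize((size, size))


print(resize_for_tower(Image.new("RGB", (500, 333)), {"image_size": 384}).size)  # (384, 384)
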
@@ -366,18 +367,19 @@ def expand2square(pil_img, background_color):
    def load_video(self, video_path, num_video_frames):
        decord.bridge.set_bridge("torch")
        video_reader = VideoReader(uri=video_path)
        image_size = self.multimodal_cfg["image_size"]

        idx = np.round(np.linspace(0, len(video_reader) - 1, num_video_frames)).astype(int)
        try:
            video_outputs = video_reader.get_batch(idx)
        except:
            print(f'bad data path {video_path}')
            video_outputs = torch.zeros(8, 336, 336, 3, dtype=torch.uint8)
            video_outputs = torch.zeros(8, image_size, image_size, 3, dtype=torch.uint8)

        b, h, w, c = video_outputs.size()
        image_tensor = torch.zeros(b, c, 336, 336, dtype=torch.uint8)
        image_tensor = torch.zeros(b, c, image_size, image_size, dtype=torch.uint8)
        video_frames = video_outputs.permute(0, 3, 1, 2).contiguous()
        video_frames = Resize(size=[336, 336], antialias=True)(video_frames)
        video_frames = Resize(size=[image_size, image_size], antialias=True)(video_frames)
        image_tensor[:, :, :, :] = video_frames

        return image_tensor
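
The frame indices above are sampled uniformly with np.linspace; a small worked example of that sampling, independent of any video file:

import numpy as np

num_video_frames, total_frames = 8, 100  # illustrative counts
idx = np.round(np.linspace(0, total_frames - 1, num_video_frames)).astype(int)
print(idx)  # [ 0 14 28 42 57 71 85 99]
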
@@ -1474,6 +1476,7 @@ def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    patch_size,
    image_size,
    n_extra_patch=0,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
@@ -1526,6 +1529,7 @@
            use_im_start_end=getattr(data_args, "mm_use_im_start_end", False),
            image_processor=getattr(data_args, "image_processor", None),
            patch_size=patch_size,
            image_size=image_size,
            n_extra_patch=n_extra_patch,
        ),
    )
22 changes: 22 additions & 0 deletions llava/train/train.py
@@ -19,6 +19,7 @@
import re
import shutil
import time
import math
from typing import Dict

import torch
@@ -103,6 +104,7 @@ def train():
            'mixtral' in model_args.model_name_or_path):
        model_cls = LlavaMixtralForCausalLM


    if model_args.vision_tower is not None:
        # NOTE: a temporary hack to address the CPU OOM problem during model loading
        if "70" in model_args.model_name_or_path:
@@ -114,8 +116,20 @@

            time.sleep(300)

        # Set RoPE scaling factor
        config = transformers.AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            trust_remote_code=model_args.trust_remote_code,
        )
        orig_ctx_len = getattr(config, "max_position_embeddings", None)
        if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
            scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
            config.rope_scaling = {"type": "linear", "factor": scaling_factor}

        model = model_cls.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            # low_cpu_mem_usage="70" in model_args.model_name_or_path,
        )
    else:
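
The linear RoPE scaling factor is the training context length divided by the model's original max_position_embeddings, rounded up to an integer; a worked example with illustrative lengths:

import math

orig_ctx_len = 4096          # e.g. a Llama-2 base context window
model_max_length = 6144      # hypothetical training sequence length
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))  # ceil(1.5) -> 2.0
rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 2.0}
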
@@ -364,6 +378,13 @@ def wrap_func(*args, **kwargs):
and "eva" not in str(type(model.get_vision_tower())).lower()
):
patch_size = 28 # qwen
elif "siglip" in str(type(model.get_vision_tower())).lower():
if "16" in model_args.vision_tower:
patch_size = 16
elif "so400m" in model_args.vision_tower:
patch_size = 14
else:
raise ValueError("Unknown siglip model, please set the patch size")
else: # clip or eva
patch_size = 14
patch_size = patch_size * 2 ** model_args.mm_projector_type.count("ds")
@@ -372,6 +393,7 @@ def wrap_func(*args, **kwargs):
        tokenizer=tokenizer,
        data_args=data_args,
        patch_size=patch_size,
        image_size=vision_config.image_size,
        n_extra_patch=n_extra_patch,
    )
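
Illustrative token-count arithmetic for the patch-size selection above, assuming a SigLIP patch-16 tower at 384x384 and a projector name containing one "ds" (downsampling) stage; both values are hypothetical and not from this diff:

image_size, patch_size = 384, 16
mm_projector_type = "mlp_ds"  # hypothetical projector string containing "ds"
patch_size = patch_size * 2 ** mm_projector_type.count("ds")  # 16 -> 32
tokens_per_side = image_size // patch_size                    # 384 // 32 = 12
print(tokens_per_side ** 2)                                   # 144 visual tokens per image
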

94 changes: 94 additions & 0 deletions llava/train/transformers_replace/siglip/__init__.py
@@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
    "configuration_siglip": [
        "SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "SiglipConfig",
        "SiglipTextConfig",
        "SiglipVisionConfig",
    ],
    "processing_siglip": ["SiglipProcessor"],
    "tokenization_siglip": ["SiglipTokenizer"],
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_siglip"] = [
        "SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SiglipModel",
        "SiglipPreTrainedModel",
        "SiglipTextModel",
        "SiglipVisionModel",
    ]


if TYPE_CHECKING:
    from .configuration_siglip import (
        SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SiglipConfig,
        SiglipTextConfig,
        SiglipVisionConfig,
    )
    from .processing_siglip import SiglipProcessor
    from .tokenization_siglip import SiglipTokenizer

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_siglip import SiglipImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_siglip import (
            SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            SiglipModel,
            SiglipPreTrainedModel,
            SiglipTextModel,
            SiglipVisionModel,
        )


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
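
This file mirrors the SigLIP package layout from Hugging Face transformers; with the _LazyModule pattern above, submodules are imported only when their names are first resolved. A small sketch of the resulting import surface, assuming these transformers_replace files shadow (or are copied over) the installed transformers.models.siglip package:

from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel  # resolving these names triggers the lazy imports

config = SiglipVisionConfig()      # library defaults, e.g. hidden_size=768, image_size=224, patch_size=16
model = SiglipVisionModel(config)  # small randomly initialized vision tower; no download needed
print(sum(p.numel() for p in model.parameters()) > 0)  # True
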