Skip to content

Commit

Permalink
whisper checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
LeshengJin committed Mar 1, 2024
1 parent 8606327 commit c88231b
Show file tree
Hide file tree
Showing 58 changed files with 7,416 additions and 25 deletions.
3 changes: 2 additions & 1 deletion python/mlc_chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
MLC Chat is the app runtime of MLC LLM.
"""
from . import protocol, serve

# from . import protocol, serve
from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
from .libinfo import __version__
5 changes: 3 additions & 2 deletions python/mlc_chat/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Load MLC LLM library and _ffi_api functions."""

import ctypes
import os
import sys
Expand All @@ -24,5 +25,5 @@ def _load_mlc_llm_lib():


# only load once here
if SKIP_LOADING_MLCLLM_SO == "0":
_LIB, _LIB_PATH = _load_mlc_llm_lib()
# if SKIP_LOADING_MLCLLM_SO == "0":
# _LIB, _LIB_PATH = _load_mlc_llm_lib()
1 change: 1 addition & 0 deletions python/mlc_chat/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The compilation pipeline for LLM applications."""

from pathlib import Path
from typing import Any, Dict, List, Optional

Expand Down
1 change: 1 addition & 0 deletions python/mlc_chat/interface/compile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Python entrypoint of compilation."""

import dataclasses
import math
from io import StringIO
Expand Down
14 changes: 14 additions & 0 deletions python/mlc_chat/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .phi import phi_loader, phi_model, phi_quantization
from .qwen import qwen_loader, qwen_model, qwen_quantization
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
from .whisper import whisper_loader, whisper_model, whisper_quantization

ModelConfig = Any
"""A ModelConfig is an object that represents a model architecture. It is required to have
Expand Down Expand Up @@ -195,4 +196,17 @@ class Model:
"group-quant": stablelm_quantization.group_quant,
},
),
"whisper": Model(
name="whisper",
model=whisper_model.WhisperForConditionalGeneration,
config=whisper_model.WhisperConfig,
source={
"huggingface-torch": whisper_loader.huggingface,
"huggingface-safetensor": whisper_loader.huggingface,
},
quantize={
"no-quant": whisper_quantization.no_quant,
"group-quant": whisper_quantization.group_quant,
},
),
}
Empty file.
51 changes: 51 additions & 0 deletions python/mlc_chat/model/whisper/whisper_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
This file specifies how MLC's Whisper parameters are mapped from other formats, for example
HuggingFace PyTorch and HuggingFace safetensors.
"""

import functools

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .whisper_model import WhisperConfig, WhisperForConditionalGeneration


def huggingface(model_config: WhisperConfig, quantization: Quantization) -> ExternMapping:
    """Build the MLC-to-HuggingFace parameter mapping for Whisper.

    Every MLC parameter is mapped to the source parameter of the same name;
    the only transformation applied is a cast to the MLC parameter's dtype.

    Parameters
    ----------
    model_config : WhisperConfig
        The configuration of the Whisper model.

    quantization : Quantization
        The quantization configuration. When it is not ``None``, the model is
        converted to ``quantization.model_dtype`` before its parameters are
        exported.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """

    def _cast(x, dtype):
        # Convert the loaded source array to the dtype of the MLC parameter.
        return x.astype(dtype)

    whisper = WhisperForConditionalGeneration(model_config)
    if quantization is not None:
        whisper.to(quantization.model_dtype)

    # Only the named parameters of the export are needed here.
    _, exported_params, _ = whisper.export_tvm(  # type: ignore[misc]
        spec=whisper.get_default_spec(),
        allow_extern=True,
    )
    params_by_name = dict(exported_params)

    param_map = ExternMapping()
    for name, param in params_by_name.items():
        # Same name on both sides: one source parameter per MLC parameter.
        # The dtype is bound now via ``partial`` so each entry keeps its own.
        convert = functools.partial(_cast, dtype=param.dtype)
        param_map.add_mapping(name, [name], convert)
    return param_map
Loading

0 comments on commit c88231b

Please sign in to comment.