diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 9043792..3458175 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -21,35 +21,57 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.9'
- - uses: pre-commit/action@v3.0.1
+ - name: Lint
+ uses: pre-commit/action@v3.0.1
+ - name: Get pip cache dir
+ id: pip-cache
+ run: |
+ python -m pip install --upgrade pip setuptools
+ echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+ - name: Cache pip
+ uses: actions/cache@v4
+ with:
+ path: ${{ steps.pip-cache.outputs.dir }}
+ key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
+ - name: Install dependencies
+ run: |
+ pip install -r requirements.txt --progress-bar off --upgrade
+ pip install -e ".[tests]" --progress-bar off --upgrade
+ - name: Check for API changes
+ run: |
+ bash shell/api_gen.sh
+ git status
+ clean=$(git status | grep "nothing to commit")
+ if [ -z "$clean" ]; then
+ echo "Please run shell/api_gen.sh to generate API."
+ exit 1
+ fi
build:
strategy:
fail-fast: false
matrix:
- python-version: [3.9]
backend: [tensorflow, jax, torch, numpy]
name: Run tests
runs-on: ubuntu-latest
env:
- PYTHON: ${{ matrix.python-version }}
KERAS_BACKEND: ${{ matrix.backend }}
steps:
- uses: actions/checkout@v4
- - name: Set up Python
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.python-version }}
+ python-version: '3.9'
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- - name: Pip cache
+ - name: Cache pip
uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
- key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+ key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 516d50c..7855983 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -23,11 +23,11 @@ jobs:
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- - name: Pip cache
+ - name: Cache pip
uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
- key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+ key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5d5d8cb..23893c5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,6 @@ repos:
rev: 5.13.2
hooks:
- id: isort
- name: isort (python)
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
diff --git a/README.md b/README.md
index d62b5d6..479ce41 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
-[![Keras](https://img.shields.io/badge/keras-v3.0.4+-success.svg)](https://github.com/keras-team/keras)
+[![Keras](https://img.shields.io/badge/keras-v3.3.0+-success.svg)](https://github.com/keras-team/keras)
[![PyPI](https://img.shields.io/pypi/v/kimm)](https://pypi.org/project/kimm/)
[![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/james77777778/kimm/issues)
[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/james77777778/keras-image-models/actions.yml?label=tests)](https://github.com/james77777778/keras-image-models/actions/workflows/actions.yml?query=branch%3Amain++)
diff --git a/api_gen.py b/api_gen.py
new file mode 100644
index 0000000..9bc27ae
--- /dev/null
+++ b/api_gen.py
@@ -0,0 +1,13 @@
+import namex
+
+from kimm._src.version import __version__
+
+namex.generate_api_files(package="kimm", code_directory="_src")
+
+# Add version string
+
+with open("kimm/__init__.py", "r") as f:
+ contents = f.read()
+with open("kimm/__init__.py", "w") as f:
+ contents += f'__version__ = "{__version__}"\n'
+ f.write(contents)
diff --git a/kimm/__init__.py b/kimm/__init__.py
index 18a5301..0f1ffcd 100644
--- a/kimm/__init__.py
+++ b/kimm/__init__.py
@@ -1,6 +1,16 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm import blocks
from kimm import export
+from kimm import layers
from kimm import models
+from kimm import timm_utils
from kimm import utils
-from kimm.utils.model_registry import list_models
+from kimm._src.utils.model_registry import list_models
+from kimm._src.version import version
__version__ = "0.2.0"
diff --git a/kimm/_src/blocks/__init__.py b/kimm/_src/blocks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/kimm/blocks/base_block.py b/kimm/_src/blocks/conv2d.py
similarity index 52%
rename from kimm/blocks/base_block.py
rename to kimm/_src/blocks/conv2d.py
index 9091c7a..62c3d19 100644
--- a/kimm/blocks/base_block.py
+++ b/kimm/_src/blocks/conv2d.py
@@ -3,25 +3,10 @@
from keras import backend
from keras import layers
-from kimm.utils import make_divisible
-
-
-def apply_activation(
- inputs, activation: typing.Optional[str] = None, name: str = "activation"
-):
- x = inputs
- if activation is not None:
- if isinstance(activation, str):
- x = layers.Activation(activation, name=name)(x)
- elif isinstance(activation, layers.Layer):
- x = activation(x)
- else:
- NotImplementedError(
- f"Unsupported activation type: {type(activation)}"
- )
- return x
+from kimm._src.kimm_export import kimm_export
+@kimm_export(parent_path=["kimm.blocks"])
def apply_conv2d_block(
inputs,
filters: typing.Optional[int] = None,
@@ -83,45 +68,8 @@ def apply_conv2d_block(
momentum=bn_momentum,
epsilon=bn_epsilon,
)(x)
- x = apply_activation(x, activation, name=name)
+ if activation is not None:
+ x = layers.Activation(activation, name=name)(x)
if has_skip:
x = layers.Add()([x, inputs])
return x
-
-
-def apply_se_block(
- inputs,
- se_ratio: float = 0.25,
- activation: typing.Optional[str] = "relu",
- gate_activation: typing.Optional[str] = "sigmoid",
- make_divisible_number: typing.Optional[int] = None,
- se_input_channels: typing.Optional[int] = None,
- name: str = "se_block",
-):
- channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
- input_channels = inputs.shape[channels_axis]
- if se_input_channels is None:
- se_input_channels = input_channels
- if make_divisible_number is None:
- se_channels = round(se_input_channels * se_ratio)
- else:
- se_channels = make_divisible(
- se_input_channels * se_ratio, make_divisible_number
- )
-
- x = inputs
- x = layers.GlobalAveragePooling2D(
- data_format=backend.image_data_format(),
- keepdims=True,
- name=f"{name}_mean",
- )(x)
- x = layers.Conv2D(
- se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
- )(x)
- x = apply_activation(x, activation, name=f"{name}_act1")
- x = layers.Conv2D(
- input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
- )(x)
- x = apply_activation(x, gate_activation, name=f"{name}_gate")
- x = layers.Multiply(name=name)([inputs, x])
- return x
diff --git a/kimm/blocks/depthwise_separation_block.py b/kimm/_src/blocks/depthwise_separation.py
similarity index 89%
rename from kimm/blocks/depthwise_separation_block.py
rename to kimm/_src/blocks/depthwise_separation.py
index a70ecee..fbbefb5 100644
--- a/kimm/blocks/depthwise_separation_block.py
+++ b/kimm/_src/blocks/depthwise_separation.py
@@ -3,10 +3,12 @@
from keras import backend
from keras import layers
-from kimm.blocks.base_block import apply_conv2d_block
-from kimm.blocks.base_block import apply_se_block
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
+@kimm_export(parent_path=["kimm.blocks"])
def apply_depthwise_separation_block(
inputs,
output_channels: int,
diff --git a/kimm/blocks/inverted_residual_block.py b/kimm/_src/blocks/inverted_residual.py
similarity index 89%
rename from kimm/blocks/inverted_residual_block.py
rename to kimm/_src/blocks/inverted_residual.py
index b5dc95c..46cbe72 100644
--- a/kimm/blocks/inverted_residual_block.py
+++ b/kimm/_src/blocks/inverted_residual.py
@@ -3,11 +3,13 @@
from keras import backend
from keras import layers
-from kimm.blocks.base_block import apply_conv2d_block
-from kimm.blocks.base_block import apply_se_block
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.utils.make_divisble import make_divisible
+@kimm_export(parent_path=["kimm.blocks"])
def apply_inverted_residual_block(
inputs,
output_channels: int,
diff --git a/kimm/_src/blocks/squeeze_and_excitation.py b/kimm/_src/blocks/squeeze_and_excitation.py
new file mode 100644
index 0000000..8a1cef0
--- /dev/null
+++ b/kimm/_src/blocks/squeeze_and_excitation.py
@@ -0,0 +1,48 @@
+import typing
+
+from keras import backend
+from keras import layers
+
+from kimm._src.kimm_export import kimm_export
+from kimm._src.utils.make_divisble import make_divisible
+
+
+@kimm_export(parent_path=["kimm.blocks"])
+def apply_se_block(
+ inputs,
+ se_ratio: float = 0.25,
+ activation: typing.Optional[str] = "relu",
+ gate_activation: typing.Optional[str] = "sigmoid",
+ make_divisible_number: typing.Optional[int] = None,
+ se_input_channels: typing.Optional[int] = None,
+ name: str = "se_block",
+):
+ channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
+ input_channels = inputs.shape[channels_axis]
+ if se_input_channels is None:
+ se_input_channels = input_channels
+ if make_divisible_number is None:
+ se_channels = round(se_input_channels * se_ratio)
+ else:
+ se_channels = make_divisible(
+ se_input_channels * se_ratio, make_divisible_number
+ )
+
+ x = inputs
+ x = layers.GlobalAveragePooling2D(
+ data_format=backend.image_data_format(),
+ keepdims=True,
+ name=f"{name}_mean",
+ )(x)
+ x = layers.Conv2D(
+ se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
+ )(x)
+ if activation is not None:
+ x = layers.Activation(activation, name=f"{name}_act1")(x)
+ x = layers.Conv2D(
+ input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
+ )(x)
+ if activation is not None:
+ x = layers.Activation(gate_activation, name=f"{name}_gate")(x)
+ x = layers.Multiply(name=name)([inputs, x])
+ return x
diff --git a/kimm/blocks/transformer_block.py b/kimm/_src/blocks/transformer.py
similarity index 91%
rename from kimm/blocks/transformer_block.py
rename to kimm/_src/blocks/transformer.py
index 42bb60b..d74cc73 100644
--- a/kimm/blocks/transformer_block.py
+++ b/kimm/_src/blocks/transformer.py
@@ -3,9 +3,11 @@
from keras import backend
from keras import layers
-from kimm import layers as kimm_layers
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.attention import Attention
+@kimm_export(parent_path=["kimm.blocks"])
def apply_mlp_block(
inputs,
hidden_dim: int,
@@ -42,6 +44,7 @@ def apply_mlp_block(
return x
+@kimm_export(parent_path=["kimm.blocks"])
def apply_transformer_block(
inputs,
dim: int,
@@ -58,7 +61,7 @@ def apply_transformer_block(
residual_1 = x
x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
- x = kimm_layers.Attention(
+ x = Attention(
dim,
num_heads,
use_qkv_bias,
diff --git a/kimm/_src/export/__init__.py b/kimm/_src/export/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/kimm/export/export_onnx.py b/kimm/_src/export/export_onnx.py
similarity index 93%
rename from kimm/export/export_onnx.py
rename to kimm/_src/export/export_onnx.py
index 28a2bd9..fd229f2 100644
--- a/kimm/export/export_onnx.py
+++ b/kimm/_src/export/export_onnx.py
@@ -6,10 +6,12 @@
from keras import models
from keras import ops
-from kimm.models import BaseModel
-from kimm.utils.module_utils import torch
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.module_utils import torch
+@kimm_export(parent_path=["kimm.export"])
def export_onnx(
model: BaseModel,
input_shape: typing.Union[int, typing.Sequence[int]],
diff --git a/kimm/export/export_onnx_test.py b/kimm/_src/export/export_onnx_test.py
similarity index 78%
rename from kimm/export/export_onnx_test.py
rename to kimm/_src/export/export_onnx_test.py
index 35e0d28..8bce75b 100644
--- a/kimm/export/export_onnx_test.py
+++ b/kimm/_src/export/export_onnx_test.py
@@ -3,14 +3,16 @@
from keras import backend
from keras.src import testing
-from kimm import export
-from kimm import models
+from kimm._src import models
+from kimm._src.export import export_onnx
class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
def get_model(self):
input_shape = [3, 224, 224] # channels_first
- model = models.MobileNetV3W050Small(include_preprocessing=False)
+ model = models.mobilenet_v3.MobileNetV3W050Small(
+ include_preprocessing=False, weights=None
+ )
return input_shape, model
@classmethod
@@ -33,4 +35,4 @@ def DISABLE_test_export_onnx_use(self):
temp_dir = self.get_temp_dir()
- export.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")
+ export_onnx.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")
diff --git a/kimm/export/export_tflite.py b/kimm/_src/export/export_tflite.py
similarity index 87%
rename from kimm/export/export_tflite.py
rename to kimm/_src/export/export_tflite.py
index f735b4b..d5c04cf 100644
--- a/kimm/export/export_tflite.py
+++ b/kimm/_src/export/export_tflite.py
@@ -7,9 +7,11 @@
from keras import models
from keras.src.utils.module_utils import tensorflow as tf
-from kimm.models import BaseModel
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+@kimm_export(parent_path=["kimm.export"])
def export_tflite(
model: BaseModel,
input_shape: typing.Union[int, typing.Sequence[int]],
@@ -20,9 +22,10 @@ def export_tflite(
):
"""Export the model to tflite format.
- Only tensorflow backend with 'channels_last' is supported. The tflite model
- will be generated using `tf.lite.TFLiteConverter.from_saved_model` and
- optimized through tflite built-in functions.
+ Only TensorFlow backend with 'channels_last' is supported. The
+ tflite model will be generated using
+ `tf.lite.TFLiteConverter.from_saved_model` and optimized through tflite
+ built-in functions.
Note that when exporting an `int8` tflite model, `representative_dataset`
must be passed.
@@ -37,8 +40,8 @@ def export_tflite(
batch_size: int, specifying the batch size of the input,
defaults to `1`.
"""
- if backend.backend() != "tensorflow":
- raise ValueError("`export_tflite` only supports tensorflow backend")
+ if backend.backend() not in ("tensorflow",):
+ raise ValueError("`export_tflite` only supports TensorFlow backend")
if backend.image_data_format() != "channels_last":
raise ValueError(
"`export_tflite` only supports 'channels_last' data format."
diff --git a/kimm/export/export_tflite_test.py b/kimm/_src/export/export_tflite_test.py
similarity index 87%
rename from kimm/export/export_tflite_test.py
rename to kimm/_src/export/export_tflite_test.py
index 15d2c24..153f442 100644
--- a/kimm/export/export_tflite_test.py
+++ b/kimm/_src/export/export_tflite_test.py
@@ -5,14 +5,16 @@
from keras import random
from keras.src import testing
-from kimm import export
-from kimm import models
+from kimm._src import models
+from kimm._src.export import export_tflite
class ExportTFLiteTest(testing.TestCase, parameterized.TestCase):
def get_model_and_representative_dataset(self):
input_shape = [224, 224, 3]
- model = models.MobileNetV3W050Small(include_preprocessing=False)
+ model = models.mobilenet_v3.MobileNetV3W050Small(
+ include_preprocessing=False, weights=None
+ )
def representative_dataset():
for _ in range(10):
@@ -39,7 +41,7 @@ def test_export_tflite_fp32(self):
(input_shape, model, _) = self.get_model_and_representative_dataset()
temp_dir = self.get_temp_dir()
- export.export_tflite(
+ export_tflite.export_tflite(
model, input_shape, f"{temp_dir}/model_fp32.onnx", "float32"
)
@@ -50,7 +52,7 @@ def test_export_tflite_fp16(self):
(input_shape, model, _) = self.get_model_and_representative_dataset()
temp_dir = self.get_temp_dir()
- export.export_tflite(
+ export_tflite.export_tflite(
model, input_shape, f"{temp_dir}/model_fp16.tflite", "float16"
)
@@ -65,7 +67,7 @@ def test_export_tflite_int8(self):
) = self.get_model_and_representative_dataset()
temp_dir = self.get_temp_dir()
- export.export_tflite(
+ export_tflite.export_tflite(
model,
input_shape,
f"{temp_dir}/model_int8.tflite",
diff --git a/kimm/_src/kimm_export.py b/kimm/_src/kimm_export.py
new file mode 100644
index 0000000..0c2929c
--- /dev/null
+++ b/kimm/_src/kimm_export.py
@@ -0,0 +1,87 @@
+try:
+ import namex
+except ImportError:
+ namex = None
+
+# These dicts reference "canonical names" only
+# (i.e. the first name an object was registered with).
+REGISTERED_NAMES_TO_OBJS = {}
+REGISTERED_OBJS_TO_NAMES = {}
+
+
+def register_internal_serializable(path, symbol):
+ global REGISTERED_NAMES_TO_OBJS
+ if isinstance(path, (list, tuple)):
+ name = path[0]
+ else:
+ name = path
+ REGISTERED_NAMES_TO_OBJS[name] = symbol
+ REGISTERED_OBJS_TO_NAMES[symbol] = name
+
+
+def get_symbol_from_name(name):
+ return REGISTERED_NAMES_TO_OBJS.get(name, None)
+
+
+def get_name_from_symbol(symbol):
+ return REGISTERED_OBJS_TO_NAMES.get(symbol, None)
+
+
+if namex:
+
+ class kimm_export:
+ def __init__(self, parent_path):
+ package = "kimm"
+
+ if isinstance(parent_path, str):
+ export_paths = [parent_path]
+ elif isinstance(parent_path, list):
+ export_paths = parent_path
+ else:
+ raise ValueError(
+ f"Invalid type for `parent_path` argument: "
+ f"Received '{parent_path}' "
+ f"of type {type(parent_path)}"
+ )
+ for p in export_paths:
+ if not p.startswith(package):
+ raise ValueError(
+ "All `export_path` values should start with "
+ f"'{package}.'. Received: parent_path={parent_path}"
+ )
+ self.package = package
+ self.parent_path = parent_path
+
+ def __call__(self, symbol):
+ if hasattr(symbol, "_api_export_path") and (
+ symbol._api_export_symbol_id == id(symbol)
+ ):
+ raise ValueError(
+ f"Symbol {symbol} is already exported as "
+ f"'{symbol._api_export_path}'. "
+ f"Cannot also export it to '{self.parent_path}'."
+ )
+ if isinstance(self.parent_path, list):
+ path = [p + f".{symbol.__name__}" for p in self.parent_path]
+ elif isinstance(self.parent_path, str):
+ path = self.parent_path + f".{symbol.__name__}"
+ symbol._api_export_path = path
+ symbol._api_export_symbol_id = id(symbol)
+
+ register_internal_serializable(path, symbol)
+ return symbol
+
+else:
+
+ class kimm_export:
+ def __init__(self, parent_path):
+ self.parent_path = parent_path
+
+ def __call__(self, symbol):
+ if isinstance(self.parent_path, list):
+ path = [p + f".{symbol.__name__}" for p in self.parent_path]
+ elif isinstance(self.parent_path, str):
+ path = self.parent_path + f".{symbol.__name__}"
+
+ register_internal_serializable(path, symbol)
+ return symbol
diff --git a/kimm/_src/layers/__init__.py b/kimm/_src/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/kimm/layers/attention.py b/kimm/_src/layers/attention.py
similarity index 97%
rename from kimm/layers/attention.py
rename to kimm/_src/layers/attention.py
index 271f10a..0783db3 100644
--- a/kimm/layers/attention.py
+++ b/kimm/_src/layers/attention.py
@@ -2,7 +2,10 @@
from keras import layers
from keras import ops
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class Attention(layers.Layer):
def __init__(
diff --git a/kimm/layers/attention_test.py b/kimm/_src/layers/attention_test.py
similarity index 92%
rename from kimm/layers/attention_test.py
rename to kimm/_src/layers/attention_test.py
index 12a6992..8f65582 100644
--- a/kimm/layers/attention_test.py
+++ b/kimm/_src/layers/attention_test.py
@@ -2,7 +2,7 @@
from absl.testing import parameterized
from keras.src import testing
-from kimm.layers.attention import Attention
+from kimm._src.layers.attention import Attention
class AttentionTest(testing.TestCase, parameterized.TestCase):
diff --git a/kimm/layers/layer_scale.py b/kimm/_src/layers/layer_scale.py
similarity index 90%
rename from kimm/layers/layer_scale.py
rename to kimm/_src/layers/layer_scale.py
index 8cb2924..3aa1d1b 100644
--- a/kimm/layers/layer_scale.py
+++ b/kimm/_src/layers/layer_scale.py
@@ -2,15 +2,17 @@
from keras import initializers
from keras import layers
from keras import ops
-from keras.initializers import Initializer
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class LayerScale(layers.Layer):
def __init__(
self,
axis: int = -1,
- initializer: Initializer = initializers.Constant(1e-5),
+ initializer: initializers.Initializer = initializers.Constant(1e-5),
**kwargs,
):
super().__init__(**kwargs)
@@ -29,8 +31,6 @@ def build(self, input_shape):
self.built = True
def call(self, inputs, training=None, mask=None):
- inputs = ops.cast(inputs, self.compute_dtype)
-
# Broadcasting only necessary for norm when the axis is not just
# the last dimension
input_shape = inputs.shape
@@ -40,7 +40,6 @@ def call(self, inputs, training=None, mask=None):
broadcast_shape[dim] = input_shape[dim]
gamma = ops.reshape(self.gamma, broadcast_shape)
gamma = ops.cast(gamma, self.compute_dtype)
-
return ops.multiply(inputs, gamma)
def get_config(self):
diff --git a/kimm/layers/layer_scale_test.py b/kimm/_src/layers/layer_scale_test.py
similarity index 91%
rename from kimm/layers/layer_scale_test.py
rename to kimm/_src/layers/layer_scale_test.py
index 6344923..2523137 100644
--- a/kimm/layers/layer_scale_test.py
+++ b/kimm/_src/layers/layer_scale_test.py
@@ -2,7 +2,7 @@
from absl.testing import parameterized
from keras.src import testing
-from kimm.layers.layer_scale import LayerScale
+from kimm._src.layers.layer_scale import LayerScale
class LayerScaleTest(testing.TestCase, parameterized.TestCase):
diff --git a/kimm/layers/learnable_affine.py b/kimm/_src/layers/learnable_affine.py
similarity index 94%
rename from kimm/layers/learnable_affine.py
rename to kimm/_src/layers/learnable_affine.py
index 6e41b3f..909637c 100644
--- a/kimm/layers/learnable_affine.py
+++ b/kimm/_src/layers/learnable_affine.py
@@ -2,7 +2,10 @@
from keras import layers
from keras import ops
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class LearnableAffine(layers.Layer):
def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs):
diff --git a/kimm/layers/learnable_affine_test.py b/kimm/_src/layers/learnable_affine_test.py
similarity index 90%
rename from kimm/layers/learnable_affine_test.py
rename to kimm/_src/layers/learnable_affine_test.py
index 8aa2335..0b23e09 100644
--- a/kimm/layers/learnable_affine_test.py
+++ b/kimm/_src/layers/learnable_affine_test.py
@@ -2,7 +2,7 @@
from absl.testing import parameterized
from keras.src import testing
-from kimm.layers.learnable_affine import LearnableAffine
+from kimm._src.layers.learnable_affine import LearnableAffine
class LearnableAffineTest(testing.TestCase, parameterized.TestCase):
diff --git a/kimm/layers/mobile_one_conv2d.py b/kimm/_src/layers/mobile_one_conv2d.py
similarity index 99%
rename from kimm/layers/mobile_one_conv2d.py
rename to kimm/_src/layers/mobile_one_conv2d.py
index 138eadd..df32aa2 100644
--- a/kimm/layers/mobile_one_conv2d.py
+++ b/kimm/_src/layers/mobile_one_conv2d.py
@@ -9,7 +9,10 @@
from keras.src.layers import Layer
from keras.src.utils.argument_validation import standardize_tuple
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class MobileOneConv2D(Layer):
def __init__(
diff --git a/kimm/layers/mobile_one_conv2d_test.py b/kimm/_src/layers/mobile_one_conv2d_test.py
similarity index 98%
rename from kimm/layers/mobile_one_conv2d_test.py
rename to kimm/_src/layers/mobile_one_conv2d_test.py
index f7ecc83..abdfbf2 100644
--- a/kimm/layers/mobile_one_conv2d_test.py
+++ b/kimm/_src/layers/mobile_one_conv2d_test.py
@@ -4,7 +4,7 @@
from keras import random
from keras.src import testing
-from kimm.layers.mobile_one_conv2d import MobileOneConv2D
+from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D
TEST_CASES = [
{
diff --git a/kimm/layers/position_embedding.py b/kimm/_src/layers/position_embedding.py
similarity index 93%
rename from kimm/layers/position_embedding.py
rename to kimm/_src/layers/position_embedding.py
index 9738aaa..b9bc4de 100644
--- a/kimm/layers/position_embedding.py
+++ b/kimm/_src/layers/position_embedding.py
@@ -2,7 +2,10 @@
from keras import layers
from keras import ops
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class PositionEmbedding(layers.Layer):
def __init__(self, **kwargs):
diff --git a/kimm/layers/position_embedding_test.py b/kimm/_src/layers/position_embedding_test.py
similarity index 93%
rename from kimm/layers/position_embedding_test.py
rename to kimm/_src/layers/position_embedding_test.py
index f6fd8ff..db31a77 100644
--- a/kimm/layers/position_embedding_test.py
+++ b/kimm/_src/layers/position_embedding_test.py
@@ -3,7 +3,7 @@
from keras import layers
from keras.src import testing
-from kimm.layers.position_embedding import PositionEmbedding
+from kimm._src.layers.position_embedding import PositionEmbedding
class PositionEmbeddingTest(testing.TestCase, parameterized.TestCase):
diff --git a/kimm/layers/rep_conv2d.py b/kimm/_src/layers/rep_conv2d.py
similarity index 99%
rename from kimm/layers/rep_conv2d.py
rename to kimm/_src/layers/rep_conv2d.py
index e3b67e8..cc223d9 100644
--- a/kimm/layers/rep_conv2d.py
+++ b/kimm/_src/layers/rep_conv2d.py
@@ -7,7 +7,10 @@
from keras.src.layers import Layer
from keras.src.utils.argument_validation import standardize_tuple
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class RepConv2D(Layer):
def __init__(
diff --git a/kimm/layers/rep_conv2d_test.py b/kimm/_src/layers/rep_conv2d_test.py
similarity index 98%
rename from kimm/layers/rep_conv2d_test.py
rename to kimm/_src/layers/rep_conv2d_test.py
index 7128f6a..e43323d 100644
--- a/kimm/layers/rep_conv2d_test.py
+++ b/kimm/_src/layers/rep_conv2d_test.py
@@ -4,7 +4,7 @@
from keras import random
from keras.src import testing
-from kimm.layers.rep_conv2d import RepConv2D
+from kimm._src.layers.rep_conv2d import RepConv2D
TEST_CASES = [
{
diff --git a/kimm/_src/models/__init__.py b/kimm/_src/models/__init__.py
new file mode 100644
index 0000000..7c32b20
--- /dev/null
+++ b/kimm/_src/models/__init__.py
@@ -0,0 +1,19 @@
+from kimm._src.models import convmixer
+from kimm._src.models import convnext
+from kimm._src.models import densenet
+from kimm._src.models import efficientnet
+from kimm._src.models import ghostnet
+from kimm._src.models import hgnet
+from kimm._src.models import inception_next
+from kimm._src.models import inception_v3
+from kimm._src.models import mobilenet_v2
+from kimm._src.models import mobilenet_v3
+from kimm._src.models import mobileone
+from kimm._src.models import mobilevit
+from kimm._src.models import regnet
+from kimm._src.models import repvgg
+from kimm._src.models import resnet
+from kimm._src.models import vgg
+from kimm._src.models import vision_transformer
+from kimm._src.models import xception
+from kimm._src.models.base_model import BaseModel
diff --git a/kimm/models/base_model.py b/kimm/_src/models/base_model.py
similarity index 97%
rename from kimm/models/base_model.py
rename to kimm/_src/models/base_model.py
index 6331361..cc8580c 100644
--- a/kimm/models/base_model.py
+++ b/kimm/_src/models/base_model.py
@@ -9,7 +9,10 @@
from keras import utils
from keras.src.applications import imagenet_utils
+from kimm._src.kimm_export import kimm_export
+
+@kimm_export(parent_path=["kimm.models", "kimm.models.base_model"])
class BaseModel(models.Model):
default_origin = (
"https://github.com/james77777778/kimm/releases/download/0.1.0/"
@@ -24,6 +27,7 @@ def __init__(
features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
**kwargs,
):
+ _include_top = getattr(self, "_include_top", True)
if not hasattr(self, "_feature_extractor"):
del features
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
@@ -50,7 +54,7 @@ def __init__(
)
filtered_features[k] = features[k]
# Add outputs
- if backend.is_keras_tensor(outputs):
+ if _include_top and backend.is_keras_tensor(outputs):
filtered_features["TOP"] = outputs
super().__init__(
inputs=inputs, outputs=filtered_features, **kwargs
diff --git a/kimm/models/convmixer.py b/kimm/_src/models/convmixer.py
similarity index 94%
rename from kimm/models/convmixer.py
rename to kimm/_src/models/convmixer.py
index 7a143f2..1e31a98 100644
--- a/kimm/models/convmixer.py
+++ b/kimm/_src/models/convmixer.py
@@ -4,8 +4,9 @@
from keras import backend
from keras import layers
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_convmixer_block(
@@ -185,6 +186,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.convmixer"])
class ConvMixer736D32(ConvMixerVariant):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
available_weights = [
@@ -203,6 +205,7 @@ class ConvMixer736D32(ConvMixerVariant):
activation = "relu"
+@kimm_export(parent_path=["kimm.models", "kimm.models.convmixer"])
class ConvMixer1024D20(ConvMixerVariant):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
@@ -221,6 +224,7 @@ class ConvMixer1024D20(ConvMixerVariant):
activation = "gelu"
+@kimm_export(parent_path=["kimm.models", "kimm.models.convmixer"])
class ConvMixer1536D20(ConvMixerVariant):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
diff --git a/kimm/models/convnext.py b/kimm/_src/models/convnext.py
similarity index 93%
rename from kimm/models/convnext.py
rename to kimm/_src/models/convnext.py
index b3d26cc..9145389 100644
--- a/kimm/models/convnext.py
+++ b/kimm/_src/models/convnext.py
@@ -5,10 +5,11 @@
from keras import initializers
from keras import layers
-from kimm import layers as kimm_layers
-from kimm.blocks import apply_mlp_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.transformer import apply_mlp_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.layer_scale import LayerScale
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_convnext_block(
@@ -61,7 +62,7 @@ def apply_convnext_block(
)
# LayerScale
- x = kimm_layers.LayerScale(
+ x = LayerScale(
axis=channels_axis,
initializer=initializers.Constant(1e-6),
name=f"{name}_layerscale",
@@ -303,6 +304,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtAtto(ConvNeXtVariant):
available_weights = [
(
@@ -321,6 +323,7 @@ class ConvNeXtAtto(ConvNeXtVariant):
use_conv_mlp = True
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtFemto(ConvNeXtVariant):
available_weights = [
(
@@ -339,6 +342,7 @@ class ConvNeXtFemto(ConvNeXtVariant):
use_conv_mlp = True
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtPico(ConvNeXtVariant):
available_weights = [
(
@@ -357,6 +361,7 @@ class ConvNeXtPico(ConvNeXtVariant):
use_conv_mlp = True
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtNano(ConvNeXtVariant):
available_weights = [
(
@@ -375,6 +380,7 @@ class ConvNeXtNano(ConvNeXtVariant):
use_conv_mlp = True
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtTiny(ConvNeXtVariant):
available_weights = [
(
@@ -393,6 +399,7 @@ class ConvNeXtTiny(ConvNeXtVariant):
use_conv_mlp = False
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtSmall(ConvNeXtVariant):
available_weights = [
(
@@ -411,6 +418,7 @@ class ConvNeXtSmall(ConvNeXtVariant):
use_conv_mlp = False
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtBase(ConvNeXtVariant):
available_weights = [
(
@@ -429,6 +437,7 @@ class ConvNeXtBase(ConvNeXtVariant):
use_conv_mlp = False
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtLarge(ConvNeXtVariant):
available_weights = [
(
@@ -447,6 +456,7 @@ class ConvNeXtLarge(ConvNeXtVariant):
use_conv_mlp = False
+@kimm_export(parent_path=["kimm.models", "kimm.models.convnext"])
class ConvNeXtXLarge(ConvNeXtVariant):
available_weights = []
diff --git a/kimm/models/densenet.py b/kimm/_src/models/densenet.py
similarity index 94%
rename from kimm/models/densenet.py
rename to kimm/_src/models/densenet.py
index 2638236..e1ecdd2 100644
--- a/kimm/models/densenet.py
+++ b/kimm/_src/models/densenet.py
@@ -4,9 +4,10 @@
from keras import backend
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_dense_layer(
@@ -216,6 +217,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.densenet"])
class DenseNet121(DenseNetVariant):
available_weights = [
(
@@ -231,6 +233,7 @@ class DenseNet121(DenseNetVariant):
default_size = 288
+@kimm_export(parent_path=["kimm.models", "kimm.models.densenet"])
class DenseNet161(DenseNetVariant):
available_weights = [
(
@@ -246,6 +249,7 @@ class DenseNet161(DenseNetVariant):
default_size = 224
+@kimm_export(parent_path=["kimm.models", "kimm.models.densenet"])
class DenseNet169(DenseNetVariant):
available_weights = [
(
@@ -261,6 +265,7 @@ class DenseNet169(DenseNetVariant):
default_size = 224
+@kimm_export(parent_path=["kimm.models", "kimm.models.densenet"])
class DenseNet201(DenseNetVariant):
available_weights = [
(
diff --git a/kimm/models/efficientnet.py b/kimm/_src/models/efficientnet.py
similarity index 92%
rename from kimm/models/efficientnet.py
rename to kimm/_src/models/efficientnet.py
index b534c24..e8027cc 100644
--- a/kimm/models/efficientnet.py
+++ b/kimm/_src/models/efficientnet.py
@@ -5,12 +5,15 @@
from keras import backend
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_depthwise_separation_block
-from kimm.blocks import apply_inverted_residual_block
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.depthwise_separation import (
+ apply_depthwise_separation_block,
+)
+from kimm._src.blocks.inverted_residual import apply_inverted_residual_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.make_divisble import make_divisible
+from kimm._src.utils.model_registry import add_model_to_registry
# type, repeat, kernel_size, strides, expansion_ratio, channels, se_ratio
# ds: depthwise separation block
@@ -313,6 +316,7 @@ def fix_config(self, config: typing.Dict):
"fix_stem_and_head_channels",
"fix_first_and_last_blocks",
"activation",
+ "config",
]
for k in unused_kwargs:
config.pop(k, None)
@@ -387,6 +391,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB0(EfficientNetVariant):
available_weights = [
(
@@ -410,6 +415,7 @@ class EfficientNetB0(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB1(EfficientNetVariant):
available_weights = [
(
@@ -433,6 +439,7 @@ class EfficientNetB1(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB2(EfficientNetVariant):
available_weights = [
(
@@ -456,6 +463,7 @@ class EfficientNetB2(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB3(EfficientNetVariant):
available_weights = [
(
@@ -479,6 +487,7 @@ class EfficientNetB3(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB4(EfficientNetVariant):
available_weights = [
(
@@ -502,6 +511,7 @@ class EfficientNetB4(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB5(EfficientNetVariant):
available_weights = [
(
@@ -525,6 +535,7 @@ class EfficientNetB5(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB6(EfficientNetVariant):
available_weights = [
(
@@ -548,6 +559,7 @@ class EfficientNetB6(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetB7(EfficientNetVariant):
available_weights = [
(
@@ -571,6 +583,7 @@ class EfficientNetB7(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetLiteB0(EfficientNetVariant):
available_weights = [
(
@@ -594,6 +607,7 @@ class EfficientNetLiteB0(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetLiteB1(EfficientNetVariant):
available_weights = [
(
@@ -617,6 +631,7 @@ class EfficientNetLiteB1(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetLiteB2(EfficientNetVariant):
available_weights = [
(
@@ -640,6 +655,7 @@ class EfficientNetLiteB2(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetLiteB3(EfficientNetVariant):
available_weights = [
(
@@ -663,6 +679,7 @@ class EfficientNetLiteB3(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetLiteB4(EfficientNetVariant):
available_weights = [
(
@@ -686,6 +703,7 @@ class EfficientNetLiteB4(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2S(EfficientNetVariant):
available_feature_keys = [
"STEM_S2",
@@ -713,6 +731,7 @@ class EfficientNetV2S(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2M(EfficientNetVariant):
available_weights = [
(
@@ -736,6 +755,7 @@ class EfficientNetV2M(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2L(EfficientNetVariant):
available_weights = [
(
@@ -759,6 +779,7 @@ class EfficientNetV2L(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2XL(EfficientNetVariant):
available_weights = [
(
@@ -782,6 +803,7 @@ class EfficientNetV2XL(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2B0(EfficientNetVariant):
available_feature_keys = [
"STEM_S2",
@@ -809,6 +831,7 @@ class EfficientNetV2B0(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2B1(EfficientNetVariant):
available_feature_keys = [
"STEM_S2",
@@ -836,6 +859,7 @@ class EfficientNetV2B1(EfficientNetVariant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2B2(EfficientNetVariant):
available_feature_keys = [
"STEM_S2",
@@ -864,6 +888,7 @@ class EfficientNetV2B2(EfficientNetVariant):
round_limit = 0.0 # fix
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class EfficientNetV2B3(EfficientNetVariant):
available_feature_keys = [
"STEM_S2",
@@ -892,6 +917,7 @@ class EfficientNetV2B3(EfficientNetVariant):
round_limit = 0.0 # fix
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class TinyNetA(EfficientNetVariant):
available_weights = [
(
@@ -914,6 +940,7 @@ class TinyNetA(EfficientNetVariant):
round_fn = round # tinynet config
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class TinyNetB(EfficientNetVariant):
available_weights = [
(
@@ -936,6 +963,7 @@ class TinyNetB(EfficientNetVariant):
round_fn = round # tinynet config
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class TinyNetC(EfficientNetVariant):
available_weights = [
(
@@ -958,6 +986,7 @@ class TinyNetC(EfficientNetVariant):
round_fn = round # tinynet config
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class TinyNetD(EfficientNetVariant):
available_weights = [
(
@@ -980,6 +1009,7 @@ class TinyNetD(EfficientNetVariant):
round_fn = round # tinynet config
+@kimm_export(parent_path=["kimm.models", "kimm.models.efficientnet"])
class TinyNetE(EfficientNetVariant):
available_weights = [
(
diff --git a/kimm/models/ghostnet.py b/kimm/_src/models/ghostnet.py
similarity index 94%
rename from kimm/models/ghostnet.py
rename to kimm/_src/models/ghostnet.py
index 92c1534..e4c51b4 100644
--- a/kimm/models/ghostnet.py
+++ b/kimm/_src/models/ghostnet.py
@@ -6,11 +6,12 @@
from keras import layers
from keras import ops
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_se_block
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.make_divisble import make_divisible
+from kimm._src.utils.model_registry import add_model_to_registry
DEFAULT_CONFIG = [
# k, t, c, SE, s
@@ -433,6 +434,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet050(GhostNetVariant):
available_weights = []
@@ -442,6 +444,7 @@ class GhostNet050(GhostNetVariant):
version = "v1"
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet100(GhostNetVariant):
available_weights = [
(
@@ -457,6 +460,7 @@ class GhostNet100(GhostNetVariant):
version = "v1"
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet130(GhostNetVariant):
available_weights = []
@@ -466,6 +470,7 @@ class GhostNet130(GhostNetVariant):
version = "v1"
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet100V2(GhostNetVariant):
available_weights = [
(
@@ -481,6 +486,7 @@ class GhostNet100V2(GhostNetVariant):
version = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet130V2(GhostNetVariant):
available_weights = [
(
@@ -496,6 +502,7 @@ class GhostNet130V2(GhostNetVariant):
version = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"])
class GhostNet160V2(GhostNetVariant):
available_weights = [
(
diff --git a/kimm/models/hgnet.py b/kimm/_src/models/hgnet.py
similarity index 95%
rename from kimm/models/hgnet.py
rename to kimm/_src/models/hgnet.py
index adf08f6..4e5d6e0 100644
--- a/kimm/models/hgnet.py
+++ b/kimm/_src/models/hgnet.py
@@ -4,10 +4,11 @@
from keras import backend
from keras import layers
-from kimm import layers as kimm_layers
-from kimm.blocks import apply_conv2d_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.learnable_affine import LearnableAffine
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
DEFAULT_V1_TINY_CONFIG = dict(
stem_type="v1",
@@ -144,7 +145,7 @@ def apply_conv_bn_act_block(
name=name,
)
if activation is not None and use_learnable_affine:
- x = kimm_layers.LearnableAffine(name=f"{name}_lab")(x)
+ x = LearnableAffine(name=f"{name}_lab")(x)
return x
@@ -514,7 +515,7 @@ def build_top(
name="head_last_conv_0",
)(x)
if use_learnable_affine:
- x = kimm_layers.LearnableAffine(name="head_last_conv_2")(x)
+ x = LearnableAffine(name="head_last_conv_2")(x)
x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
x = layers.Flatten()(x)
x = layers.Dense(
@@ -594,6 +595,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetTiny(HGNetVariant):
available_weights = [
(
@@ -607,6 +609,7 @@ class HGNetTiny(HGNetVariant):
config = "v1_tiny"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetSmall(HGNetVariant):
available_weights = [
(
@@ -620,6 +623,7 @@ class HGNetSmall(HGNetVariant):
config = "v1_small"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetBase(HGNetVariant):
available_weights = [
(
@@ -633,6 +637,7 @@ class HGNetBase(HGNetVariant):
config = "v1_base"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B0(HGNetVariant):
available_weights = [
(
@@ -646,6 +651,7 @@ class HGNetV2B0(HGNetVariant):
config = "v2_b0"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B1(HGNetVariant):
available_weights = [
(
@@ -659,6 +665,7 @@ class HGNetV2B1(HGNetVariant):
config = "v2_b1"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B2(HGNetVariant):
available_weights = [
(
@@ -672,6 +679,7 @@ class HGNetV2B2(HGNetVariant):
config = "v2_b2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B3(HGNetVariant):
available_weights = [
(
@@ -685,6 +693,7 @@ class HGNetV2B3(HGNetVariant):
config = "v2_b3"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B4(HGNetVariant):
available_weights = [
(
@@ -698,6 +707,7 @@ class HGNetV2B4(HGNetVariant):
config = "v2_b4"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B5(HGNetVariant):
available_weights = [
(
@@ -711,6 +721,7 @@ class HGNetV2B5(HGNetVariant):
config = "v2_b5"
+@kimm_export(parent_path=["kimm.models", "kimm.models.hgnet"])
class HGNetV2B6(HGNetVariant):
available_weights = [
(
diff --git a/kimm/models/inception_next.py b/kimm/_src/models/inception_next.py
similarity index 95%
rename from kimm/models/inception_next.py
rename to kimm/_src/models/inception_next.py
index 554cb27..6793a9c 100644
--- a/kimm/models/inception_next.py
+++ b/kimm/_src/models/inception_next.py
@@ -7,10 +7,11 @@
from keras import layers
from keras import ops
-from kimm import layers as kimm_layers
-from kimm.blocks import apply_mlp_block
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.transformer import apply_mlp_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.layer_scale import LayerScale
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_inception_depthwise_conv2d(
@@ -86,7 +87,7 @@ def apply_metanext_block(
use_conv_mlp=True,
name=f"{name}_mlp",
)
- x = kimm_layers.LayerScale(
+ x = LayerScale(
axis=channels_axis,
initializer=initializers.Constant(1e-6),
name=f"{name}_layerscale",
@@ -304,6 +305,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.inception_next"])
class InceptionNeXtTiny(InceptionNeXtVariant):
available_weights = [
(
@@ -320,6 +322,7 @@ class InceptionNeXtTiny(InceptionNeXtVariant):
activation = "gelu"
+@kimm_export(parent_path=["kimm.models", "kimm.models.inception_next"])
class InceptionNeXtSmall(InceptionNeXtVariant):
available_weights = [
(
@@ -336,6 +339,7 @@ class InceptionNeXtSmall(InceptionNeXtVariant):
activation = "gelu"
+@kimm_export(parent_path=["kimm.models", "kimm.models.inception_next"])
class InceptionNeXtBase(InceptionNeXtVariant):
available_weights = [
(
diff --git a/kimm/models/inception_v3.py b/kimm/_src/models/inception_v3.py
similarity index 97%
rename from kimm/models/inception_v3.py
rename to kimm/_src/models/inception_v3.py
index 81db7f0..7c8ae8b 100644
--- a/kimm/models/inception_v3.py
+++ b/kimm/_src/models/inception_v3.py
@@ -5,9 +5,10 @@
from keras import backend
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
_apply_conv2d_block = functools.partial(
apply_conv2d_block, activation="relu", bn_epsilon=1e-3, padding="valid"
@@ -312,6 +313,7 @@ def fix_config(self, config: typing.Dict):
# Model Definition
+@kimm_export(parent_path=["kimm.models", "kimm.models.inception_v3"])
class InceptionV3(InceptionV3Base):
available_weights = [
(
diff --git a/kimm/models/mobilenet_v2.py b/kimm/_src/models/mobilenet_v2.py
similarity index 90%
rename from kimm/models/mobilenet_v2.py
rename to kimm/_src/models/mobilenet_v2.py
index f7e284e..e155587 100644
--- a/kimm/models/mobilenet_v2.py
+++ b/kimm/_src/models/mobilenet_v2.py
@@ -3,12 +3,15 @@
import keras
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_depthwise_separation_block
-from kimm.blocks import apply_inverted_residual_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.depthwise_separation import (
+ apply_depthwise_separation_block,
+)
+from kimm._src.blocks.inverted_residual import apply_inverted_residual_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.make_divisble import make_divisible
+from kimm._src.utils.model_registry import add_model_to_registry
DEFAULT_CONFIG = [
# type, repeat, kernel_size, strides, expansion_ratio, channels
@@ -194,6 +197,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v2"])
class MobileNetV2W050(MobileNetV2Variant):
available_weights = [
(
@@ -210,6 +214,7 @@ class MobileNetV2W050(MobileNetV2Variant):
config = "default"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v2"])
class MobileNetV2W100(MobileNetV2Variant):
available_weights = [
(
@@ -226,6 +231,7 @@ class MobileNetV2W100(MobileNetV2Variant):
config = "default"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v2"])
class MobileNetV2W110(MobileNetV2Variant):
available_weights = [
(
@@ -242,6 +248,7 @@ class MobileNetV2W110(MobileNetV2Variant):
config = "default"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v2"])
class MobileNetV2W120(MobileNetV2Variant):
available_weights = [
(
@@ -258,6 +265,7 @@ class MobileNetV2W120(MobileNetV2Variant):
config = "default"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v2"])
class MobileNetV2W140(MobileNetV2Variant):
available_weights = [
(
diff --git a/kimm/models/mobilenet_v3.py b/kimm/_src/models/mobilenet_v3.py
similarity index 93%
rename from kimm/models/mobilenet_v3.py
rename to kimm/_src/models/mobilenet_v3.py
index 1013b24..48d0f33 100644
--- a/kimm/models/mobilenet_v3.py
+++ b/kimm/_src/models/mobilenet_v3.py
@@ -5,12 +5,15 @@
import keras
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_depthwise_separation_block
-from kimm.blocks import apply_inverted_residual_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.depthwise_separation import (
+ apply_depthwise_separation_block,
+)
+from kimm._src.blocks.inverted_residual import apply_inverted_residual_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.make_divisble import make_divisible
+from kimm._src.utils.model_registry import add_model_to_registry
DEFAULT_SMALL_CONFIG = [
# type, repeat, kernel_size, strides, expansion_ratio, channels, se_ratio,
@@ -360,6 +363,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W050Small(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -380,6 +384,7 @@ class MobileNetV3W050Small(MobileNetV3Variant):
config = "small"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W075Small(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -400,6 +405,7 @@ class MobileNetV3W075Small(MobileNetV3Variant):
config = "small"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W100Small(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -420,6 +426,7 @@ class MobileNetV3W100Small(MobileNetV3Variant):
config = "small"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W100SmallMinimal(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -446,6 +453,7 @@ class MobileNetV3W100SmallMinimal(MobileNetV3Variant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W100Large(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -482,6 +490,7 @@ def build_preprocessing(self, inputs, mode="imagenet"):
return super().build_preprocessing(inputs, mode)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class MobileNetV3W100LargeMinimal(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -511,6 +520,7 @@ class MobileNetV3W100LargeMinimal(MobileNetV3Variant):
padding = "same"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class LCNet035(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -525,6 +535,7 @@ class LCNet035(MobileNetV3Variant):
config = "lcnet"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class LCNet050(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -545,6 +556,7 @@ class LCNet050(MobileNetV3Variant):
config = "lcnet"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class LCNet075(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -565,6 +577,7 @@ class LCNet075(MobileNetV3Variant):
config = "lcnet"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class LCNet100(MobileNetV3Variant):
available_feature_keys = [
"STEM_S2",
@@ -585,6 +598,7 @@ class LCNet100(MobileNetV3Variant):
config = "lcnet"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilenet_v3"])
class LCNet150(MobileNetV3):
available_feature_keys = [
"STEM_S2",
diff --git a/kimm/models/mobileone.py b/kimm/_src/models/mobileone.py
similarity index 93%
rename from kimm/models/mobileone.py
rename to kimm/_src/models/mobileone.py
index 36bb67d..78fd912 100644
--- a/kimm/models/mobileone.py
+++ b/kimm/_src/models/mobileone.py
@@ -3,9 +3,10 @@
import keras
from keras import backend
-from kimm import layers as kimm_layers
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
@keras.saving.register_keras_serializable(package="kimm")
@@ -54,7 +55,7 @@ def __init__(
features = {}
# stem
- x = kimm_layers.MobileOneConv2D(
+ x = MobileOneConv2D(
stem_channels,
3,
2,
@@ -82,7 +83,7 @@ def __init__(
name1 = f"stages_{current_stage_idx}_{current_block_idx}"
name2 = f"stages_{current_stage_idx}_{current_block_idx+1}"
# Depthwise
- x = kimm_layers.MobileOneConv2D(
+ x = MobileOneConv2D(
input_channels,
3,
strides,
@@ -94,7 +95,7 @@ def __init__(
name=name1,
)(x)
# Pointwise
- x = kimm_layers.MobileOneConv2D(
+ x = MobileOneConv2D(
c,
1,
1,
@@ -215,6 +216,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobileone"])
class MobileOneS0(MobileOneVariant):
available_weights = [
(
@@ -231,6 +233,7 @@ class MobileOneS0(MobileOneVariant):
branch_size = 4
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobileone"])
class MobileOneS1(MobileOneVariant):
available_weights = [
(
@@ -247,6 +250,7 @@ class MobileOneS1(MobileOneVariant):
branch_size = 1
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobileone"])
class MobileOneS2(MobileOneVariant):
available_weights = [
(
@@ -263,6 +267,7 @@ class MobileOneS2(MobileOneVariant):
branch_size = 1
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobileone"])
class MobileOneS3(MobileOneVariant):
available_weights = [
(
diff --git a/kimm/models/mobilevit.py b/kimm/_src/models/mobilevit.py
similarity index 95%
rename from kimm/models/mobilevit.py
rename to kimm/_src/models/mobilevit.py
index 8565c85..d2e51d7 100644
--- a/kimm/models/mobilevit.py
+++ b/kimm/_src/models/mobilevit.py
@@ -6,13 +6,14 @@
from keras import layers
from keras import ops
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_inverted_residual_block
-from kimm.blocks import apply_mlp_block
-from kimm.blocks import apply_transformer_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
-from kimm.utils import make_divisible
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.inverted_residual import apply_inverted_residual_block
+from kimm._src.blocks.transformer import apply_mlp_block
+from kimm._src.blocks.transformer import apply_transformer_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.make_divisble import make_divisible
+from kimm._src.utils.model_registry import add_model_to_registry
# type, repeat, kernel_size, channels, strides, expansion_ratio,
# transformer_dim, transformer_depth, patch_size
@@ -679,6 +680,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTXXS(MobileViTVariant):
available_weights = [
(
@@ -695,6 +697,7 @@ class MobileViTXXS(MobileViTVariant):
config = "v1_xxs"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTXS(MobileViTVariant):
available_weights = [
(
@@ -711,6 +714,7 @@ class MobileViTXS(MobileViTVariant):
config = "v1_xs"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTS(MobileViTVariant):
available_weights = [
(
@@ -771,6 +775,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W050(MobileViTV2Variant):
available_weights = [
(
@@ -786,6 +791,7 @@ class MobileViTV2W050(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W075(MobileViTV2Variant):
available_weights = [
(
@@ -801,6 +807,7 @@ class MobileViTV2W075(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W100(MobileViTV2Variant):
available_weights = [
(
@@ -816,6 +823,7 @@ class MobileViTV2W100(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W125(MobileViTV2Variant):
available_weights = [
(
@@ -831,6 +839,7 @@ class MobileViTV2W125(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W150(MobileViTV2Variant):
available_weights = [
(
@@ -846,6 +855,7 @@ class MobileViTV2W150(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W175(MobileViTV2Variant):
available_weights = [
(
@@ -861,6 +871,7 @@ class MobileViTV2W175(MobileViTV2Variant):
config = "v2"
+@kimm_export(parent_path=["kimm.models", "kimm.models.mobilevit"])
class MobileViTV2W200(MobileViTV2Variant):
available_weights = [
(
diff --git a/kimm/models/models_test.py b/kimm/_src/models/models_test.py
similarity index 58%
rename from kimm/models/models_test.py
rename to kimm/_src/models/models_test.py
index 8167947..e2c385a 100644
--- a/kimm/models/models_test.py
+++ b/kimm/_src/models/models_test.py
@@ -2,26 +2,88 @@
import pytest
import tensorflow as tf
from absl.testing import parameterized
-from keras import applications
-from keras import backend
-from keras import models
-from keras import ops
-from keras import random
-from keras import utils
from keras.src import testing
-from kimm import models as kimm_models
-from kimm.utils import make_divisible
+from kimm._src import models as kimm_models
+from kimm._src.utils.make_divisble import make_divisible
-decode_predictions = applications.imagenet_utils.decode_predictions
+decode_predictions = keras.applications.imagenet_utils.decode_predictions
+
+# Test BaseModel
+
+
+class SampleModel(kimm_models.BaseModel):
+ available_feature_keys = [f"S{2**i}" for i in range(1, 6)]
+
+ def __init__(self, **kwargs):
+ self.set_properties(kwargs)
+ inputs = keras.layers.Input(shape=[224, 224, 3])
+
+ features = {}
+ s2 = keras.layers.Conv2D(3, 1, 2, use_bias=False)(inputs)
+ features["S2"] = s2
+ s4 = keras.layers.Conv2D(3, 1, 2, use_bias=False)(s2)
+ features["S4"] = s4
+ s8 = keras.layers.Conv2D(3, 1, 2, use_bias=False)(s4)
+ features["S8"] = s8
+ s16 = keras.layers.Conv2D(3, 1, 2, use_bias=False)(s8)
+ features["S16"] = s16
+ s32 = keras.layers.Conv2D(3, 1, 2, use_bias=False)(s16)
+ features["S32"] = s32
+ outputs = keras.layers.GlobalAveragePooling2D()(s32)
+ super().__init__(
+ inputs=inputs, outputs=outputs, features=features, **kwargs
+ )
+
+
+class BaseModelTest(testing.TestCase, parameterized.TestCase):
+ def test_feature_extractor(self):
+ x = keras.random.uniform([1, 224, 224, 3])
+
+ # Test availiable_feature_keys
+ self.assertContainsSubset(
+ ["S2", "S4", "S8", "S16", "S32"],
+ SampleModel.available_feature_keys,
+ )
+
+ # Test feature_extractor=False
+ model = SampleModel()
+ y = model(x, training=False)
+ self.assertNotIsInstance(y, dict)
+ self.assertEqual(list(y.shape), [1, 3])
+
+ # Test feature_extractor=True
+ model = SampleModel(feature_extractor=True)
+ y = model(x, training=False)
+ self.assertIsInstance(y, dict)
+ self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
+ self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])
+
+ # Test feature_extractor=True with feature_keys
+ model = SampleModel(
+ include_top=False,
+ feature_extractor=True,
+ feature_keys=["S2", "S16", "S32"],
+ )
+ y = model(x, training=False)
+ self.assertIsInstance(y, dict)
+ self.assertNotIn("S4", y)
+ self.assertNotIn("S8", y)
+ self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
+ self.assertEqual(list(y["S16"].shape), [1, 14, 14, 3])
+ self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])
+ self.assertNotIn("TOP", y)
+
+
+# Test some small models
# name, class, default_size, features (name, shape),
# weights (defaults to imagenet)
MODEL_CONFIGS = [
# convmixer
(
- kimm_models.ConvMixer736D32.__name__,
- kimm_models.ConvMixer736D32,
+ kimm_models.convmixer.ConvMixer736D32.__name__,
+ kimm_models.convmixer.ConvMixer736D32,
224,
[
("STEM", [1, 32, 32, 768]),
@@ -30,8 +92,8 @@
),
# convnext
(
- kimm_models.ConvNeXtAtto.__name__,
- kimm_models.ConvNeXtAtto,
+ kimm_models.convnext.ConvNeXtAtto.__name__,
+ kimm_models.convnext.ConvNeXtAtto,
288,
[
("STEM_S4", [1, 72, 72, 40]),
@@ -43,8 +105,8 @@
),
# densenet
(
- kimm_models.DenseNet121.__name__,
- kimm_models.DenseNet121,
+ kimm_models.densenet.DenseNet121.__name__,
+ kimm_models.densenet.DenseNet121,
224,
[
("STEM_S4", [1, 56, 56, 64]),
@@ -56,8 +118,8 @@
),
# efficientnet
(
- kimm_models.EfficientNetB2.__name__,
- kimm_models.EfficientNetB2,
+ kimm_models.efficientnet.EfficientNetB2.__name__,
+ kimm_models.efficientnet.EfficientNetB2,
260,
[
("STEM_S2", [1, 130, 130, make_divisible(32 * 1.1)]),
@@ -68,8 +130,8 @@
],
),
(
- kimm_models.EfficientNetLiteB2.__name__,
- kimm_models.EfficientNetLiteB2,
+ kimm_models.efficientnet.EfficientNetLiteB2.__name__,
+ kimm_models.efficientnet.EfficientNetLiteB2,
260,
[
("STEM_S2", [1, 130, 130, make_divisible(32 * 1.1)]),
@@ -80,8 +142,8 @@
],
),
(
- kimm_models.EfficientNetV2S.__name__,
- kimm_models.EfficientNetV2S,
+ kimm_models.efficientnet.EfficientNetV2S.__name__,
+ kimm_models.efficientnet.EfficientNetV2S,
300,
[
("STEM_S2", [1, 150, 150, make_divisible(24 * 1.0)]),
@@ -92,8 +154,8 @@
],
),
(
- kimm_models.EfficientNetV2B0.__name__,
- kimm_models.EfficientNetV2B0,
+ kimm_models.efficientnet.EfficientNetV2B0.__name__,
+ kimm_models.efficientnet.EfficientNetV2B0,
192,
[
("STEM_S2", [1, 96, 96, make_divisible(32 * 1.0)]),
@@ -104,8 +166,8 @@
],
),
(
- kimm_models.TinyNetE.__name__,
- kimm_models.TinyNetE,
+ kimm_models.efficientnet.TinyNetE.__name__,
+ kimm_models.efficientnet.TinyNetE,
106,
[
("STEM_S2", [1, 53, 53, 32]),
@@ -117,8 +179,8 @@
),
# ghostnet
(
- kimm_models.GhostNet100.__name__,
- kimm_models.GhostNet100,
+ kimm_models.ghostnet.GhostNet100.__name__,
+ kimm_models.ghostnet.GhostNet100,
224,
[
("STEM_S2", [1, 112, 112, 16]),
@@ -129,8 +191,8 @@
],
),
(
- kimm_models.GhostNet100V2.__name__,
- kimm_models.GhostNet100V2,
+ kimm_models.ghostnet.GhostNet100V2.__name__,
+ kimm_models.ghostnet.GhostNet100V2,
224,
[
("STEM_S2", [1, 112, 112, 16]),
@@ -142,8 +204,8 @@
),
# hgnet
(
- kimm_models.HGNetTiny.__name__,
- kimm_models.HGNetTiny,
+ kimm_models.hgnet.HGNetTiny.__name__,
+ kimm_models.hgnet.HGNetTiny,
224,
[
("STEM_S4", [1, 56, 56, 96]),
@@ -154,8 +216,8 @@
],
),
(
- kimm_models.HGNetV2B0.__name__,
- kimm_models.HGNetV2B0,
+ kimm_models.hgnet.HGNetV2B0.__name__,
+ kimm_models.hgnet.HGNetV2B0,
224,
[
("STEM_S4", [1, 56, 56, 16]),
@@ -167,8 +229,8 @@
),
# inception_next
(
- kimm_models.InceptionNeXtTiny.__name__,
- kimm_models.InceptionNeXtTiny,
+ kimm_models.inception_next.InceptionNeXtTiny.__name__,
+ kimm_models.inception_next.InceptionNeXtTiny,
224,
[
("STEM_S4", [1, 56, 56, 96]),
@@ -180,8 +242,8 @@
),
# inception_v3
(
- kimm_models.InceptionV3.__name__,
- kimm_models.InceptionV3,
+ kimm_models.inception_v3.InceptionV3.__name__,
+ kimm_models.inception_v3.InceptionV3,
299,
[
("STEM_S2", [1, 147, 147, 64]),
@@ -193,8 +255,8 @@
),
# mobilenet_v2
(
- kimm_models.MobileNetV2W050.__name__,
- kimm_models.MobileNetV2W050,
+ kimm_models.mobilenet_v2.MobileNetV2W050.__name__,
+ kimm_models.mobilenet_v2.MobileNetV2W050,
224,
[
("STEM_S2", [1, 112, 112, make_divisible(32 * 0.5)]),
@@ -206,8 +268,8 @@
),
# mobilenet_v3
(
- kimm_models.LCNet100.__name__,
- kimm_models.LCNet100,
+ kimm_models.mobilenet_v3.LCNet100.__name__,
+ kimm_models.mobilenet_v3.LCNet100,
224,
[
("STEM_S2", [1, 112, 112, make_divisible(16 * 1.0)]),
@@ -218,8 +280,8 @@
],
),
(
- kimm_models.MobileNetV3W050Small.__name__,
- kimm_models.MobileNetV3W050Small,
+ kimm_models.mobilenet_v3.MobileNetV3W050Small.__name__,
+ kimm_models.mobilenet_v3.MobileNetV3W050Small,
224,
[
("STEM_S2", [1, 112, 112, 16]),
@@ -230,8 +292,8 @@
],
),
(
- kimm_models.MobileNetV3W100SmallMinimal.__name__,
- kimm_models.MobileNetV3W100SmallMinimal,
+ kimm_models.mobilenet_v3.MobileNetV3W100SmallMinimal.__name__,
+ kimm_models.mobilenet_v3.MobileNetV3W100SmallMinimal,
224,
[
("STEM_S2", [1, 112, 112, make_divisible(16 * 1.0)]),
@@ -243,8 +305,8 @@
),
# mobileone
(
- kimm_models.MobileOneS0.__name__,
- kimm_models.MobileOneS0,
+ kimm_models.mobileone.MobileOneS0.__name__,
+ kimm_models.mobileone.MobileOneS0,
224,
[
("STEM_S2", [1, 112, 112, 48]),
@@ -256,8 +318,8 @@
),
# mobilevit
(
- kimm_models.MobileViTS.__name__,
- kimm_models.MobileViTS,
+ kimm_models.mobilevit.MobileViTS.__name__,
+ kimm_models.mobilevit.MobileViTS,
256,
[
("STEM_S2", [1, 128, 128, 16]),
@@ -269,8 +331,8 @@
),
# mobilevitv2
(
- kimm_models.MobileViTV2W050.__name__,
- kimm_models.MobileViTV2W050,
+ kimm_models.mobilevit.MobileViTV2W050.__name__,
+ kimm_models.mobilevit.MobileViTV2W050,
256,
[
("STEM_S2", [1, 128, 128, 16]),
@@ -282,8 +344,8 @@
),
# regnet
(
- kimm_models.RegNetX002.__name__,
- kimm_models.RegNetX002,
+ kimm_models.regnet.RegNetX002.__name__,
+ kimm_models.regnet.RegNetX002,
224,
[
("STEM_S2", [1, 112, 112, 32]),
@@ -294,8 +356,8 @@
],
),
(
- kimm_models.RegNetY002.__name__,
- kimm_models.RegNetY002,
+ kimm_models.regnet.RegNetY002.__name__,
+ kimm_models.regnet.RegNetY002,
224,
[
("STEM_S2", [1, 112, 112, 32]),
@@ -307,8 +369,8 @@
),
# repvgg
(
- kimm_models.RepVGGA0.__name__,
- kimm_models.RepVGGA0,
+ kimm_models.repvgg.RepVGGA0.__name__,
+ kimm_models.repvgg.RepVGGA0,
224,
[
("STEM_S2", [1, 112, 112, 48]),
@@ -320,8 +382,8 @@
),
# resnet
(
- kimm_models.ResNet18.__name__,
- kimm_models.ResNet18,
+ kimm_models.resnet.ResNet18.__name__,
+ kimm_models.resnet.ResNet18,
224,
[
("STEM_S2", [1, 112, 112, 64]),
@@ -332,8 +394,8 @@
],
),
(
- kimm_models.ResNet50.__name__,
- kimm_models.ResNet50,
+ kimm_models.resnet.ResNet50.__name__,
+ kimm_models.resnet.ResNet50,
224,
[
("STEM_S2", [1, 112, 112, 64]),
@@ -345,8 +407,8 @@
),
# vgg
(
- kimm_models.VGG11.__name__,
- kimm_models.VGG11,
+ kimm_models.vgg.VGG11.__name__,
+ kimm_models.vgg.VGG11,
224,
[
("BLOCK0_S1", [1, 224, 224, 64]),
@@ -360,22 +422,22 @@
),
# vision_transformer
(
- kimm_models.VisionTransformerTiny16.__name__,
- kimm_models.VisionTransformerTiny16,
+ kimm_models.vision_transformer.VisionTransformerTiny16.__name__,
+ kimm_models.vision_transformer.VisionTransformerTiny16,
384,
[*((f"BLOCK{i}", [1, 577, 192]) for i in range(5))],
),
(
- kimm_models.VisionTransformerTiny32.__name__,
- kimm_models.VisionTransformerTiny32,
+ kimm_models.vision_transformer.VisionTransformerTiny32.__name__,
+ kimm_models.vision_transformer.VisionTransformerTiny32,
384,
[*((f"BLOCK{i}", [1, 145, 192]) for i in range(5))],
None, # no weights
),
# xception
(
- kimm_models.Xception.__name__,
- kimm_models.Xception,
+ kimm_models.xception.Xception.__name__,
+ kimm_models.xception.Xception,
299,
[
("STEM_S2", [1, 147, 147, 64]),
@@ -389,67 +451,77 @@
@pytest.mark.requires_trainable_backend # numpy is too slow to test
-class ModelTest(testing.TestCase, parameterized.TestCase):
+class ModelsTest(testing.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
- cls.original_image_data_format = backend.image_data_format()
+ cls.original_image_data_format = keras.backend.image_data_format()
@classmethod
def tearDownClass(cls):
- backend.set_image_data_format(cls.original_image_data_format)
+ keras.backend.set_image_data_format(cls.original_image_data_format)
@parameterized.named_parameters(MODEL_CONFIGS)
- def test_model_base_channels_last(
- self, model_class, image_size, features, weights="imagenet"
+ def test_predict(
+ self,
+ model_class,
+ image_size,
+ features,
+ weights="imagenet",
):
- backend.set_image_data_format("channels_last")
- model = model_class(weights=weights)
+ # We also enable feature_extractor=True in model instantiation to
+ # speed up the testing
+
+ # Load the image
image_path = keras.utils.get_file(
"elephant.png",
"https://github.com/james77777778/keras-image-models/releases/download/0.1.0/elephant.png",
)
- # preprocessing
- image = utils.load_img(image_path, target_size=(image_size, image_size))
- image = utils.img_to_array(image, data_format="channels_last")
- x = ops.convert_to_tensor(image)
- x = ops.expand_dims(x, axis=0)
+ image = keras.utils.load_img(
+ image_path, target_size=(image_size, image_size)
+ )
+
+ # Test channels_last and feature_extractor=True
+ keras.backend.set_image_data_format("channels_last")
+ model = model_class(weights=weights, feature_extractor=True)
+ x = keras.utils.img_to_array(image, data_format="channels_last")
+ x = keras.ops.expand_dims(keras.ops.convert_to_tensor(x), axis=0)
y = model(x, training=False)
+ # Verify output correctness
+ prob = y["TOP"]
if weights == "imagenet":
- names = [p[1] for p in decode_predictions(y)[0]]
+ names = [p[1] for p in decode_predictions(prob)[0]]
# Test correct label is in top 3 (weak correctness test).
self.assertIn("African_elephant", names[:3])
elif weights is None:
- self.assertEqual(list(y.shape), [1, 1000])
+ self.assertEqual(list(prob.shape), [1, 1000])
- @parameterized.named_parameters(MODEL_CONFIGS)
- def test_model_base_channels_first(
- self, model_class, image_size, features, weights="imagenet"
- ):
+ # Verify features
+ self.assertIsInstance(y, dict)
+ self.assertContainsSubset(
+ model_class.available_feature_keys, list(y.keys())
+ )
+ for feature_info in features:
+ name, shape = feature_info
+ self.assertEqual(list(y[name].shape), shape)
+
+ # Test channels_first
if (
len(tf.config.list_physical_devices("GPU")) == 0
- and backend.backend() == "tensorflow"
+ and keras.backend.backend() == "tensorflow"
):
- self.skipTest(
- "Conv2D doesn't support channels_first using CPU with "
- "tensorflow backend"
- )
+ # TensorFlow doesn't support channels_first using CPU
+ return
- backend.set_image_data_format("channels_first")
+ keras.backend.set_image_data_format("channels_first")
model = model_class(weights=weights)
- image_path = keras.utils.get_file(
- "elephant.png",
- "https://github.com/james77777778/keras-image-models/releases/download/0.1.0/elephant.png",
- )
- # preprocessing
- image = utils.load_img(image_path, target_size=(image_size, image_size))
- image = utils.img_to_array(image, data_format="channels_first")
- x = ops.convert_to_tensor(image)
- x = ops.expand_dims(x, axis=0)
+ x = keras.utils.img_to_array(image, data_format="channels_first")
+ x = keras.ops.expand_dims(keras.ops.convert_to_tensor(x), axis=0)
y = model(x, training=False)
+ # Verify output correctness
if weights == "imagenet":
names = [p[1] for p in decode_predictions(y)[0]]
# Test correct label is in top 3 (weak correctness test).
@@ -457,30 +529,24 @@ def test_model_base_channels_first(
elif weights is None:
self.assertEqual(list(y.shape), [1, 1000])
- @parameterized.named_parameters(MODEL_CONFIGS)
- def test_model_feature_extractor(
- self, model_class, image_size, features, weights="imagenet"
- ):
- backend.set_image_data_format("channels_last")
- x = random.uniform([1, image_size, image_size, 3]) * 255.0
- model = model_class(weights=None, feature_extractor=True)
-
- y = model(x, training=False)
-
- self.assertIsInstance(y, dict)
- self.assertContainsSubset(
- model_class.available_feature_keys, list(y.keys())
- )
- for feature_info in features:
- name, shape = feature_info
- self.assertEqual(list(y[name].shape), shape)
-
@parameterized.named_parameters(
- (kimm_models.RepVGGA0.__name__, kimm_models.RepVGGA0, 224),
- (kimm_models.MobileOneS0.__name__, kimm_models.MobileOneS0, 224),
+ (
+ kimm_models.repvgg.RepVGGA0.__name__,
+ kimm_models.repvgg.RepVGGA0,
+ 224,
+ ),
+ (
+ kimm_models.mobileone.MobileOneS0.__name__,
+ kimm_models.mobileone.MobileOneS0,
+ 224,
+ ),
)
- def test_model_get_reparameterized_model(self, model_class, image_size):
- x = random.uniform([1, image_size, image_size, 3]) * 255.0
+ def test_get_reparameterized_model(
+ self,
+ model_class,
+ image_size,
+ ):
+ x = keras.random.uniform([1, image_size, image_size, 3]) * 255.0
model = model_class()
reparameterized_model = model.get_reparameterized_model()
@@ -491,17 +557,23 @@ def test_model_get_reparameterized_model(self, model_class, image_size):
@pytest.mark.serialization
@parameterized.named_parameters(MODEL_CONFIGS)
- def test_model_serialization(
- self, model_class, image_size, features, weights="imagenet"
+ def test_serialization(
+ self,
+ model_class,
+ image_size,
+ features,
+ weights="imagenet",
):
- backend.set_image_data_format("channels_last")
- x = random.uniform([1, image_size, image_size, 3]) * 255.0
+ keras.backend.set_image_data_format("channels_last")
+ x = keras.random.uniform([1, image_size, image_size, 3]) * 255.0
temp_dir = self.get_temp_dir()
model1 = model_class(weights=None)
+ if hasattr(model1, "get_reparameterized_model"):
+ model1 = model1.get_reparameterized_model()
y1 = model1(x, training=False)
model1.save(temp_dir + "/model.keras")
- model2 = models.load_model(temp_dir + "/model.keras")
+ model2 = keras.models.load_model(temp_dir + "/model.keras")
y2 = model2(x, training=False)
self.assertAllClose(y1, y2)
diff --git a/kimm/models/regnet.py b/kimm/_src/models/regnet.py
similarity index 90%
rename from kimm/models/regnet.py
rename to kimm/_src/models/regnet.py
index c0d6c62..0ad42f9 100644
--- a/kimm/models/regnet.py
+++ b/kimm/_src/models/regnet.py
@@ -5,10 +5,11 @@
from keras import backend
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.blocks import apply_se_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def _adjust_widths_and_groups(widths, groups, expansion_ratio):
@@ -285,6 +286,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX002(RegNetVariant):
available_weights = [
(
@@ -303,6 +305,7 @@ class RegNetX002(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY002(RegNetVariant):
available_weights = [
(
@@ -321,6 +324,7 @@ class RegNetY002(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX004(RegNetVariant):
available_weights = [
(
@@ -339,6 +343,7 @@ class RegNetX004(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY004(RegNetVariant):
available_weights = [
(
@@ -357,6 +362,7 @@ class RegNetY004(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX006(RegNetVariant):
available_weights = [
(
@@ -375,6 +381,7 @@ class RegNetX006(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY006(RegNetVariant):
available_weights = [
(
@@ -393,6 +400,7 @@ class RegNetY006(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX008(RegNetVariant):
available_weights = [
(
@@ -411,6 +419,7 @@ class RegNetX008(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY008(RegNetVariant):
available_weights = [
(
@@ -429,6 +438,7 @@ class RegNetY008(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX016(RegNetVariant):
available_weights = [
(
@@ -447,6 +457,7 @@ class RegNetX016(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY016(RegNetVariant):
available_weights = [
(
@@ -465,6 +476,7 @@ class RegNetY016(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX032(RegNetVariant):
available_weights = [
(
@@ -483,6 +495,7 @@ class RegNetX032(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY032(RegNetVariant):
available_weights = [
(
@@ -501,6 +514,7 @@ class RegNetY032(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX040(RegNetVariant):
available_weights = [
(
@@ -519,6 +533,7 @@ class RegNetX040(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY040(RegNetVariant):
available_weights = [
(
@@ -537,6 +552,7 @@ class RegNetY040(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX064(RegNetVariant):
available_weights = [
(
@@ -555,6 +571,7 @@ class RegNetX064(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY064(RegNetVariant):
available_weights = [
(
@@ -573,6 +590,7 @@ class RegNetY064(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX080(RegNetVariant):
available_weights = [
(
@@ -591,6 +609,7 @@ class RegNetX080(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY080(RegNetVariant):
available_weights = [
(
@@ -609,6 +628,7 @@ class RegNetY080(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX120(RegNetVariant):
available_weights = [
(
@@ -627,6 +647,7 @@ class RegNetX120(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY120(RegNetVariant):
available_weights = [
(
@@ -645,6 +666,7 @@ class RegNetY120(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX160(RegNetVariant):
available_weights = [
(
@@ -663,6 +685,7 @@ class RegNetX160(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY160(RegNetVariant):
available_weights = [
(
@@ -681,6 +704,7 @@ class RegNetY160(RegNetVariant):
se_ratio = 0.25
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetX320(RegNetVariant):
available_weights = [
(
@@ -699,6 +723,7 @@ class RegNetX320(RegNetVariant):
se_ratio = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.regnet"])
class RegNetY320(RegNetVariant):
available_weights = [
(
diff --git a/kimm/models/repvgg.py b/kimm/_src/models/repvgg.py
similarity index 92%
rename from kimm/models/repvgg.py
rename to kimm/_src/models/repvgg.py
index 1b2a402..ffcdd13 100644
--- a/kimm/models/repvgg.py
+++ b/kimm/_src/models/repvgg.py
@@ -3,9 +3,10 @@
import keras
from keras import backend
-from kimm import layers as kimm_layers
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.rep_conv2d import RepConv2D
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
@keras.saving.register_keras_serializable(package="kimm")
@@ -53,7 +54,7 @@ def __init__(
features = {}
# stem
- x = kimm_layers.RepConv2D(
+ x = RepConv2D(
stem_channels,
3,
2,
@@ -77,7 +78,7 @@ def __init__(
input_channels = x.shape[channels_axis]
has_skip = input_channels == c and strides == 1
name = f"stages_{current_stage_idx}_{current_block_idx}"
- x = kimm_layers.RepConv2D(
+ x = RepConv2D(
c,
3,
strides,
@@ -190,6 +191,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGA0(RepVGGVariant):
available_weights = [
(
@@ -205,6 +207,7 @@ class RepVGGA0(RepVGGVariant):
stem_channels = 48
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGA1(RepVGGVariant):
available_weights = [
(
@@ -220,6 +223,7 @@ class RepVGGA1(RepVGGVariant):
stem_channels = 64
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGA2(RepVGGVariant):
available_weights = [
(
@@ -235,6 +239,7 @@ class RepVGGA2(RepVGGVariant):
stem_channels = 64
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGB0(RepVGGVariant):
available_weights = [
(
@@ -250,6 +255,7 @@ class RepVGGB0(RepVGGVariant):
stem_channels = 64
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGB1(RepVGGVariant):
available_weights = [
(
@@ -265,6 +271,7 @@ class RepVGGB1(RepVGGVariant):
stem_channels = 64
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGB2(RepVGGVariant):
available_weights = [
(
@@ -280,6 +287,7 @@ class RepVGGB2(RepVGGVariant):
stem_channels = 64
+@kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"])
class RepVGGB3(RepVGG):
available_weights = [
(
diff --git a/kimm/models/resnet.py b/kimm/_src/models/resnet.py
similarity index 93%
rename from kimm/models/resnet.py
rename to kimm/_src/models/resnet.py
index deeb1c9..9f6d25d 100644
--- a/kimm/models/resnet.py
+++ b/kimm/_src/models/resnet.py
@@ -4,9 +4,10 @@
from keras import backend
from keras import layers
-from kimm.blocks import apply_conv2d_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_basic_block(
@@ -241,6 +242,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.resnet"])
class ResNet18(ResNetVariant):
available_weights = [
(
@@ -255,6 +257,7 @@ class ResNet18(ResNetVariant):
num_blocks = [2, 2, 2, 2]
+@kimm_export(parent_path=["kimm.models", "kimm.models.resnet"])
class ResNet34(ResNetVariant):
available_weights = [
(
@@ -269,6 +272,7 @@ class ResNet34(ResNetVariant):
num_blocks = [3, 4, 6, 3]
+@kimm_export(parent_path=["kimm.models", "kimm.models.resnet"])
class ResNet50(ResNetVariant):
available_weights = [
(
@@ -283,6 +287,7 @@ class ResNet50(ResNetVariant):
num_blocks = [3, 4, 6, 3]
+@kimm_export(parent_path=["kimm.models", "kimm.models.resnet"])
class ResNet101(ResNetVariant):
available_weights = [
(
@@ -297,6 +302,7 @@ class ResNet101(ResNetVariant):
num_blocks = [3, 4, 23, 3]
+@kimm_export(parent_path=["kimm.models", "kimm.models.resnet"])
class ResNet152(ResNetVariant):
available_weights = [
(
diff --git a/kimm/models/vgg.py b/kimm/_src/models/vgg.py
similarity index 94%
rename from kimm/models/vgg.py
rename to kimm/_src/models/vgg.py
index 63acb65..a7b57a7 100644
--- a/kimm/models/vgg.py
+++ b/kimm/_src/models/vgg.py
@@ -4,8 +4,9 @@
from keras import backend
from keras import layers
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
DEFAULT_VGG11_CONFIG = [
64,
@@ -245,6 +246,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.vgg"])
class VGG11(VGGVariant):
available_weights = [
(
@@ -258,6 +260,7 @@ class VGG11(VGGVariant):
config = "vgg11"
+@kimm_export(parent_path=["kimm.models", "kimm.models.vgg"])
class VGG13(VGGVariant):
available_weights = [
(
@@ -271,6 +274,7 @@ class VGG13(VGGVariant):
config = "vgg13"
+@kimm_export(parent_path=["kimm.models", "kimm.models.vgg"])
class VGG16(VGGVariant):
available_weights = [
(
@@ -284,6 +288,7 @@ class VGG16(VGGVariant):
config = "vgg16"
+@kimm_export(parent_path=["kimm.models", "kimm.models.vgg"])
class VGG19(VGGVariant):
available_weights = [
(
diff --git a/kimm/models/vision_transformer.py b/kimm/_src/models/vision_transformer.py
similarity index 91%
rename from kimm/models/vision_transformer.py
rename to kimm/_src/models/vision_transformer.py
index 0295676..a06f894 100644
--- a/kimm/models/vision_transformer.py
+++ b/kimm/_src/models/vision_transformer.py
@@ -6,10 +6,11 @@
from keras import layers
from keras import ops
-from kimm import layers as kimm_layers
-from kimm.blocks import apply_transformer_block
-from kimm.models.base_model import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.blocks.transformer import apply_transformer_block
+from kimm._src.kimm_export import kimm_export
+from kimm._src.layers.position_embedding import PositionEmbedding
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
@keras.saving.register_keras_serializable(package="kimm")
@@ -63,7 +64,7 @@ def __init__(
x = ops.transpose(x, [0, 2, 3, 1])
x = layers.Reshape((-1, embed_dim))(x)
- x = kimm_layers.PositionEmbedding(name="postition_embedding")(x)
+ x = PositionEmbedding(name="postition_embedding")(x)
features["EMBEDDING"] = x
x = layers.Dropout(pos_dropout_rate, name="pos_dropout")(x)
@@ -205,6 +206,7 @@ def __init__(
)
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerTiny16(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -229,6 +231,7 @@ class VisionTransformerTiny16(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerTiny32(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -247,6 +250,7 @@ class VisionTransformerTiny32(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerSmall16(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -271,6 +275,7 @@ class VisionTransformerSmall16(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerSmall32(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -295,6 +300,7 @@ class VisionTransformerSmall32(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerBase16(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -319,6 +325,7 @@ class VisionTransformerBase16(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerBase32(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -343,6 +350,7 @@ class VisionTransformerBase32(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerLarge16(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
@@ -361,6 +369,7 @@ class VisionTransformerLarge16(VisionTransformerVariant):
pos_dropout_rate = 0.0
+@kimm_export(parent_path=["kimm.models", "kimm.models.vision_transformer"])
class VisionTransformerLarge32(VisionTransformerVariant):
available_feature_keys = [
"EMBEDDING",
diff --git a/kimm/models/xception.py b/kimm/_src/models/xception.py
similarity index 96%
rename from kimm/models/xception.py
rename to kimm/_src/models/xception.py
index d421f06..1a38dc1 100644
--- a/kimm/models/xception.py
+++ b/kimm/_src/models/xception.py
@@ -4,8 +4,9 @@
from keras import backend
from keras import layers
-from kimm.models import BaseModel
-from kimm.utils import add_model_to_registry
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import add_model_to_registry
def apply_xception_block(
@@ -157,6 +158,7 @@ def fix_config(self, config: typing.Dict):
# Model Definition
+@kimm_export(parent_path=["kimm.models", "kimm.models.xception"])
class Xception(XceptionBase):
available_weights = [
(
diff --git a/kimm/_src/utils/__init__.py b/kimm/_src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/kimm/utils/make_divisble.py b/kimm/_src/utils/make_divisble.py
similarity index 100%
rename from kimm/utils/make_divisble.py
rename to kimm/_src/utils/make_divisble.py
diff --git a/kimm/utils/model_registry.py b/kimm/_src/utils/model_registry.py
similarity index 89%
rename from kimm/utils/model_registry.py
rename to kimm/_src/utils/model_registry.py
index 5f5fc40..9b70302 100644
--- a/kimm/utils/model_registry.py
+++ b/kimm/_src/utils/model_registry.py
@@ -2,6 +2,8 @@
import typing
import warnings
+from kimm._src.kimm_export import kimm_export
+
# {
# "name", # str
# "feature_extractor", # bool
@@ -33,9 +35,9 @@ def clear_registry():
def add_model_to_registry(model_cls, weights: typing.Optional[str] = None):
- from kimm.models.base_model import BaseModel
+ from kimm._src.models.base_model import BaseModel
- # deal with __all__
+ # Deal with __all__
mod = sys.modules[model_cls.__module__]
model_name = model_cls.__name__
if hasattr(mod, "__all__"):
@@ -43,7 +45,7 @@ def add_model_to_registry(model_cls, weights: typing.Optional[str] = None):
else:
mod.__all__ = [model_name]
- # add model information
+ # Add model information
feature_extractor = False
feature_keys = []
if issubclass(model_cls, BaseModel):
@@ -71,22 +73,23 @@ def add_model_to_registry(model_cls, weights: typing.Optional[str] = None):
)
+@kimm_export(parent_path=["kimm", "kimm.utils"])
def list_models(
name: typing.Optional[str] = None,
feature_extractor: typing.Optional[bool] = None,
weights: typing.Optional[typing.Union[bool, str]] = None,
-):
+) -> typing.List[str]:
result_names: typing.Set = set()
for info in MODEL_REGISTRY:
- # add by default
+ # Add by default
result_names.add(info["name"])
need_remove = False
- # match string (simple implementation)
+ # Match string (simple implementation)
if name is not None:
need_remove = not _match_string(name, info["name"])
- # filter by feature_extractor and weights
+ # Filter by feature_extractor and weights
if (
feature_extractor is not None
and info["feature_extractor"] is not feature_extractor
@@ -100,7 +103,6 @@ def list_models(
elif isinstance(weights, str):
if weights.lower() != info["weights"]:
need_remove = True
-
if need_remove:
result_names.remove(info["name"])
return sorted(result_names)
diff --git a/kimm/utils/model_registry_test.py b/kimm/_src/utils/model_registry_test.py
similarity index 92%
rename from kimm/utils/model_registry_test.py
rename to kimm/_src/utils/model_registry_test.py
index 93a1047..0122a25 100644
--- a/kimm/utils/model_registry_test.py
+++ b/kimm/_src/utils/model_registry_test.py
@@ -1,11 +1,11 @@
from keras import models
from keras.src import testing
-from kimm.models.base_model import BaseModel
-from kimm.utils.model_registry import MODEL_REGISTRY
-from kimm.utils.model_registry import add_model_to_registry
-from kimm.utils.model_registry import clear_registry
-from kimm.utils.model_registry import list_models
+from kimm._src.models.base_model import BaseModel
+from kimm._src.utils.model_registry import MODEL_REGISTRY
+from kimm._src.utils.model_registry import add_model_to_registry
+from kimm._src.utils.model_registry import clear_registry
+from kimm._src.utils.model_registry import list_models
class DummyModel(models.Model):
diff --git a/kimm/utils/model_utils.py b/kimm/_src/utils/model_utils.py
similarity index 88%
rename from kimm/utils/model_utils.py
rename to kimm/_src/utils/model_utils.py
index 283d99f..f1a370a 100644
--- a/kimm/utils/model_utils.py
+++ b/kimm/_src/utils/model_utils.py
@@ -1,6 +1,8 @@
-from kimm.models.base_model import BaseModel
+from kimm._src.kimm_export import kimm_export
+from kimm._src.models.base_model import BaseModel
+@kimm_export(parent_path=["kimm.utils"])
def get_reparameterized_model(model: BaseModel):
if not hasattr(model, "get_reparameterized_model"):
raise ValueError(
diff --git a/kimm/utils/model_utils_test.py b/kimm/_src/utils/model_utils_test.py
similarity index 89%
rename from kimm/utils/model_utils_test.py
rename to kimm/_src/utils/model_utils_test.py
index 1bb44ed..2cfe374 100644
--- a/kimm/utils/model_utils_test.py
+++ b/kimm/_src/utils/model_utils_test.py
@@ -1,9 +1,9 @@
from keras import random
from keras.src import testing
-from kimm.models.regnet import RegNetX002
-from kimm.models.repvgg import RepVGG
-from kimm.utils.model_utils import get_reparameterized_model
+from kimm._src.models.regnet import RegNetX002
+from kimm._src.models.repvgg import RepVGG
+from kimm._src.utils.model_utils import get_reparameterized_model
class ModelUtilsTest(testing.TestCase):
diff --git a/kimm/utils/module_utils.py b/kimm/_src/utils/module_utils.py
similarity index 100%
rename from kimm/utils/module_utils.py
rename to kimm/_src/utils/module_utils.py
diff --git a/kimm/utils/timm_utils.py b/kimm/_src/utils/timm_utils.py
similarity index 95%
rename from kimm/utils/timm_utils.py
rename to kimm/_src/utils/timm_utils.py
index 1caab2e..c5d80c8 100644
--- a/kimm/utils/timm_utils.py
+++ b/kimm/_src/utils/timm_utils.py
@@ -3,6 +3,8 @@
import keras
import numpy as np
+from kimm._src.kimm_export import kimm_export
+
def _is_useless_weights(name: str):
if "num_batches_tracked" in name:
@@ -18,6 +20,7 @@ def _is_non_trainable_weights(name: str):
return False
+@kimm_export(parent_path=["kimm.timm_utils"])
def separate_torch_state_dict(state_dict: typing.OrderedDict):
trainable_state_dict = state_dict.copy()
non_trainable_state_dict = state_dict.copy()
@@ -39,6 +42,7 @@ def separate_torch_state_dict(state_dict: typing.OrderedDict):
return trainable_state_dict, non_trainable_state_dict
+@kimm_export(parent_path=["kimm.timm_utils"])
def separate_keras_weights(keras_model: keras.Model):
trainable_weights = []
non_trainable_weights = []
@@ -67,6 +71,7 @@ def separate_keras_weights(keras_model: keras.Model):
return trainable_weights, non_trainable_weights
+@kimm_export(parent_path=["kimm.timm_utils"])
def assign_weights(
keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray
):
@@ -107,6 +112,7 @@ def assign_weights(
)
+@kimm_export(parent_path=["kimm.timm_utils"])
def is_same_weights(
keras_name: str,
keras_weights: keras.Variable,
diff --git a/kimm/_src/version.py b/kimm/_src/version.py
new file mode 100644
index 0000000..7f0d903
--- /dev/null
+++ b/kimm/_src/version.py
@@ -0,0 +1,8 @@
+from kimm._src.kimm_export import kimm_export
+
+__version__ = "0.2.0"
+
+
+@kimm_export("kimm")
+def version():
+ return __version__
diff --git a/kimm/blocks/__init__.py b/kimm/blocks/__init__.py
index b113ecd..d3cdcd2 100644
--- a/kimm/blocks/__init__.py
+++ b/kimm/blocks/__init__.py
@@ -1,9 +1,14 @@
-from kimm.blocks.base_block import apply_activation
-from kimm.blocks.base_block import apply_conv2d_block
-from kimm.blocks.base_block import apply_se_block
-from kimm.blocks.depthwise_separation_block import (
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.blocks.conv2d import apply_conv2d_block
+from kimm._src.blocks.depthwise_separation import (
apply_depthwise_separation_block,
)
-from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
-from kimm.blocks.transformer_block import apply_mlp_block
-from kimm.blocks.transformer_block import apply_transformer_block
+from kimm._src.blocks.inverted_residual import apply_inverted_residual_block
+from kimm._src.blocks.squeeze_and_excitation import apply_se_block
+from kimm._src.blocks.transformer import apply_mlp_block
+from kimm._src.blocks.transformer import apply_transformer_block
diff --git a/kimm/export/__init__.py b/kimm/export/__init__.py
index a3000a7..a646eaa 100644
--- a/kimm/export/__init__.py
+++ b/kimm/export/__init__.py
@@ -1,2 +1,8 @@
-from kimm.export.export_onnx import export_onnx
-from kimm.export.export_tflite import export_tflite
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.export.export_onnx import export_onnx
+from kimm._src.export.export_tflite import export_tflite
diff --git a/kimm/layers/__init__.py b/kimm/layers/__init__.py
index e85a569..bdbe229 100644
--- a/kimm/layers/__init__.py
+++ b/kimm/layers/__init__.py
@@ -1,6 +1,12 @@
-from kimm.layers.attention import Attention
-from kimm.layers.layer_scale import LayerScale
-from kimm.layers.learnable_affine import LearnableAffine
-from kimm.layers.mobile_one_conv2d import MobileOneConv2D
-from kimm.layers.position_embedding import PositionEmbedding
-from kimm.layers.rep_conv2d import RepConv2D
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.layers.attention import Attention
+from kimm._src.layers.layer_scale import LayerScale
+from kimm._src.layers.learnable_affine import LearnableAffine
+from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D
+from kimm._src.layers.position_embedding import PositionEmbedding
+from kimm._src.layers.rep_conv2d import RepConv2D
diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py
index 2cc253d..688f0cf 100644
--- a/kimm/models/__init__.py
+++ b/kimm/models/__init__.py
@@ -1,19 +1,167 @@
-from kimm.models.base_model import BaseModel
-from kimm.models.convmixer import * # noqa:F403
-from kimm.models.convnext import * # noqa:F403
-from kimm.models.densenet import * # noqa:F403
-from kimm.models.efficientnet import * # noqa:F403
-from kimm.models.ghostnet import * # noqa:F403
-from kimm.models.hgnet import * # noqa:F403
-from kimm.models.inception_next import * # noqa:F403
-from kimm.models.inception_v3 import * # noqa:F403
-from kimm.models.mobilenet_v2 import * # noqa:F403
-from kimm.models.mobilenet_v3 import * # noqa:F403
-from kimm.models.mobileone import * # noqa:F403
-from kimm.models.mobilevit import * # noqa:F403
-from kimm.models.regnet import * # noqa:F403
-from kimm.models.repvgg import * # noqa:F403
-from kimm.models.resnet import * # noqa:F403
-from kimm.models.vgg import * # noqa:F403
-from kimm.models.vision_transformer import * # noqa:F403
-from kimm.models.xception import * # noqa:F403
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.base_model import BaseModel
+from kimm._src.models.convmixer import ConvMixer736D32
+from kimm._src.models.convmixer import ConvMixer1024D20
+from kimm._src.models.convmixer import ConvMixer1536D20
+from kimm._src.models.convnext import ConvNeXtAtto
+from kimm._src.models.convnext import ConvNeXtBase
+from kimm._src.models.convnext import ConvNeXtFemto
+from kimm._src.models.convnext import ConvNeXtLarge
+from kimm._src.models.convnext import ConvNeXtNano
+from kimm._src.models.convnext import ConvNeXtPico
+from kimm._src.models.convnext import ConvNeXtSmall
+from kimm._src.models.convnext import ConvNeXtTiny
+from kimm._src.models.convnext import ConvNeXtXLarge
+from kimm._src.models.densenet import DenseNet121
+from kimm._src.models.densenet import DenseNet161
+from kimm._src.models.densenet import DenseNet169
+from kimm._src.models.densenet import DenseNet201
+from kimm._src.models.efficientnet import EfficientNetB0
+from kimm._src.models.efficientnet import EfficientNetB1
+from kimm._src.models.efficientnet import EfficientNetB2
+from kimm._src.models.efficientnet import EfficientNetB3
+from kimm._src.models.efficientnet import EfficientNetB4
+from kimm._src.models.efficientnet import EfficientNetB5
+from kimm._src.models.efficientnet import EfficientNetB6
+from kimm._src.models.efficientnet import EfficientNetB7
+from kimm._src.models.efficientnet import EfficientNetLiteB0
+from kimm._src.models.efficientnet import EfficientNetLiteB1
+from kimm._src.models.efficientnet import EfficientNetLiteB2
+from kimm._src.models.efficientnet import EfficientNetLiteB3
+from kimm._src.models.efficientnet import EfficientNetLiteB4
+from kimm._src.models.efficientnet import EfficientNetV2B0
+from kimm._src.models.efficientnet import EfficientNetV2B1
+from kimm._src.models.efficientnet import EfficientNetV2B2
+from kimm._src.models.efficientnet import EfficientNetV2B3
+from kimm._src.models.efficientnet import EfficientNetV2L
+from kimm._src.models.efficientnet import EfficientNetV2M
+from kimm._src.models.efficientnet import EfficientNetV2S
+from kimm._src.models.efficientnet import EfficientNetV2XL
+from kimm._src.models.efficientnet import TinyNetA
+from kimm._src.models.efficientnet import TinyNetB
+from kimm._src.models.efficientnet import TinyNetC
+from kimm._src.models.efficientnet import TinyNetD
+from kimm._src.models.efficientnet import TinyNetE
+from kimm._src.models.ghostnet import GhostNet050
+from kimm._src.models.ghostnet import GhostNet100
+from kimm._src.models.ghostnet import GhostNet100V2
+from kimm._src.models.ghostnet import GhostNet130
+from kimm._src.models.ghostnet import GhostNet130V2
+from kimm._src.models.ghostnet import GhostNet160V2
+from kimm._src.models.hgnet import HGNetBase
+from kimm._src.models.hgnet import HGNetSmall
+from kimm._src.models.hgnet import HGNetTiny
+from kimm._src.models.hgnet import HGNetV2B0
+from kimm._src.models.hgnet import HGNetV2B1
+from kimm._src.models.hgnet import HGNetV2B2
+from kimm._src.models.hgnet import HGNetV2B3
+from kimm._src.models.hgnet import HGNetV2B4
+from kimm._src.models.hgnet import HGNetV2B5
+from kimm._src.models.hgnet import HGNetV2B6
+from kimm._src.models.inception_next import InceptionNeXtBase
+from kimm._src.models.inception_next import InceptionNeXtSmall
+from kimm._src.models.inception_next import InceptionNeXtTiny
+from kimm._src.models.inception_v3 import InceptionV3
+from kimm._src.models.mobilenet_v2 import MobileNetV2W050
+from kimm._src.models.mobilenet_v2 import MobileNetV2W100
+from kimm._src.models.mobilenet_v2 import MobileNetV2W110
+from kimm._src.models.mobilenet_v2 import MobileNetV2W120
+from kimm._src.models.mobilenet_v2 import MobileNetV2W140
+from kimm._src.models.mobilenet_v3 import LCNet035
+from kimm._src.models.mobilenet_v3 import LCNet050
+from kimm._src.models.mobilenet_v3 import LCNet075
+from kimm._src.models.mobilenet_v3 import LCNet100
+from kimm._src.models.mobilenet_v3 import LCNet150
+from kimm._src.models.mobilenet_v3 import MobileNetV3W050Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W075Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100Large
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100LargeMinimal
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100SmallMinimal
+from kimm._src.models.mobileone import MobileOneS0
+from kimm._src.models.mobileone import MobileOneS1
+from kimm._src.models.mobileone import MobileOneS2
+from kimm._src.models.mobileone import MobileOneS3
+from kimm._src.models.mobilevit import MobileViTS
+from kimm._src.models.mobilevit import MobileViTV2W050
+from kimm._src.models.mobilevit import MobileViTV2W075
+from kimm._src.models.mobilevit import MobileViTV2W100
+from kimm._src.models.mobilevit import MobileViTV2W125
+from kimm._src.models.mobilevit import MobileViTV2W150
+from kimm._src.models.mobilevit import MobileViTV2W175
+from kimm._src.models.mobilevit import MobileViTV2W200
+from kimm._src.models.mobilevit import MobileViTXS
+from kimm._src.models.mobilevit import MobileViTXXS
+from kimm._src.models.regnet import RegNetX002
+from kimm._src.models.regnet import RegNetX004
+from kimm._src.models.regnet import RegNetX006
+from kimm._src.models.regnet import RegNetX008
+from kimm._src.models.regnet import RegNetX016
+from kimm._src.models.regnet import RegNetX032
+from kimm._src.models.regnet import RegNetX040
+from kimm._src.models.regnet import RegNetX064
+from kimm._src.models.regnet import RegNetX080
+from kimm._src.models.regnet import RegNetX120
+from kimm._src.models.regnet import RegNetX160
+from kimm._src.models.regnet import RegNetX320
+from kimm._src.models.regnet import RegNetY002
+from kimm._src.models.regnet import RegNetY004
+from kimm._src.models.regnet import RegNetY006
+from kimm._src.models.regnet import RegNetY008
+from kimm._src.models.regnet import RegNetY016
+from kimm._src.models.regnet import RegNetY032
+from kimm._src.models.regnet import RegNetY040
+from kimm._src.models.regnet import RegNetY064
+from kimm._src.models.regnet import RegNetY080
+from kimm._src.models.regnet import RegNetY120
+from kimm._src.models.regnet import RegNetY160
+from kimm._src.models.regnet import RegNetY320
+from kimm._src.models.repvgg import RepVGGA0
+from kimm._src.models.repvgg import RepVGGA1
+from kimm._src.models.repvgg import RepVGGA2
+from kimm._src.models.repvgg import RepVGGB0
+from kimm._src.models.repvgg import RepVGGB1
+from kimm._src.models.repvgg import RepVGGB2
+from kimm._src.models.repvgg import RepVGGB3
+from kimm._src.models.resnet import ResNet18
+from kimm._src.models.resnet import ResNet34
+from kimm._src.models.resnet import ResNet50
+from kimm._src.models.resnet import ResNet101
+from kimm._src.models.resnet import ResNet152
+from kimm._src.models.vgg import VGG11
+from kimm._src.models.vgg import VGG13
+from kimm._src.models.vgg import VGG16
+from kimm._src.models.vgg import VGG19
+from kimm._src.models.vision_transformer import VisionTransformerBase16
+from kimm._src.models.vision_transformer import VisionTransformerBase32
+from kimm._src.models.vision_transformer import VisionTransformerLarge16
+from kimm._src.models.vision_transformer import VisionTransformerLarge32
+from kimm._src.models.vision_transformer import VisionTransformerSmall16
+from kimm._src.models.vision_transformer import VisionTransformerSmall32
+from kimm._src.models.vision_transformer import VisionTransformerTiny16
+from kimm._src.models.vision_transformer import VisionTransformerTiny32
+from kimm._src.models.xception import Xception
+from kimm.models import base_model
+from kimm.models import convmixer
+from kimm.models import convnext
+from kimm.models import densenet
+from kimm.models import efficientnet
+from kimm.models import ghostnet
+from kimm.models import hgnet
+from kimm.models import inception_next
+from kimm.models import inception_v3
+from kimm.models import mobilenet_v2
+from kimm.models import mobilenet_v3
+from kimm.models import mobileone
+from kimm.models import mobilevit
+from kimm.models import regnet
+from kimm.models import repvgg
+from kimm.models import resnet
+from kimm.models import vgg
+from kimm.models import vision_transformer
+from kimm.models import xception
diff --git a/kimm/models/base_model/__init__.py b/kimm/models/base_model/__init__.py
new file mode 100644
index 0000000..dd62169
--- /dev/null
+++ b/kimm/models/base_model/__init__.py
@@ -0,0 +1,7 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.base_model import BaseModel
diff --git a/kimm/models/base_model_test.py b/kimm/models/base_model_test.py
deleted file mode 100644
index bfba31c..0000000
--- a/kimm/models/base_model_test.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from absl.testing import parameterized
-from keras import layers
-from keras import random
-from keras.src import testing
-
-from kimm.models.base_model import BaseModel
-
-
-class SampleModel(BaseModel):
- def __init__(self, **kwargs):
- self.set_properties(kwargs)
- inputs = layers.Input(shape=[224, 224, 3])
-
- features = {}
- s2 = layers.Conv2D(3, 1, 2, use_bias=False)(inputs)
- features["S2"] = s2
- s4 = layers.Conv2D(3, 1, 2, use_bias=False)(s2)
- features["S4"] = s4
- s8 = layers.Conv2D(3, 1, 2, use_bias=False)(s4)
- features["S8"] = s8
- s16 = layers.Conv2D(3, 1, 2, use_bias=False)(s8)
- features["S16"] = s16
- s32 = layers.Conv2D(3, 1, 2, use_bias=False)(s16)
- features["S32"] = s32
- super().__init__(
- inputs=inputs, outputs=s32, features=features, **kwargs
- )
-
- @staticmethod
- def available_feature_keys():
- # predefined for better UX
- return [f"S{2**i}" for i in range(1, 6)]
-
- def get_config(self):
- return super().get_config()
-
-
-class BaseModelTest(testing.TestCase, parameterized.TestCase):
- def test_feature_extractor(self):
- x = random.uniform([1, 224, 224, 3])
-
- # availiable_feature_keys
- self.assertContainsSubset(
- ["S2", "S4", "S8", "S16", "S32"],
- SampleModel.available_feature_keys(),
- )
-
- # feature_extractor=False
- model = SampleModel()
-
- y = model(x, training=False)
-
- self.assertNotIsInstance(y, dict)
- self.assertEqual(list(y.shape), [1, 7, 7, 3])
-
- # feature_extractor=True
- model = SampleModel(feature_extractor=True)
-
- y = model(x, training=False)
-
- self.assertIsInstance(y, dict)
- self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
- self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])
-
- # feature_extractor=True with feature_keys
- model = SampleModel(
- feature_extractor=True, feature_keys=["S2", "S16", "S32"]
- )
-
- y = model(x, training=False)
-
- self.assertIsInstance(y, dict)
- self.assertNotIn("S4", y)
- self.assertNotIn("S8", y)
- self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
- self.assertEqual(list(y["S16"].shape), [1, 14, 14, 3])
- self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])
diff --git a/kimm/models/convmixer/__init__.py b/kimm/models/convmixer/__init__.py
new file mode 100644
index 0000000..f609fa8
--- /dev/null
+++ b/kimm/models/convmixer/__init__.py
@@ -0,0 +1,9 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.convmixer import ConvMixer736D32
+from kimm._src.models.convmixer import ConvMixer1024D20
+from kimm._src.models.convmixer import ConvMixer1536D20
diff --git a/kimm/models/convnext/__init__.py b/kimm/models/convnext/__init__.py
new file mode 100644
index 0000000..a16129b
--- /dev/null
+++ b/kimm/models/convnext/__init__.py
@@ -0,0 +1,15 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.convnext import ConvNeXtAtto
+from kimm._src.models.convnext import ConvNeXtBase
+from kimm._src.models.convnext import ConvNeXtFemto
+from kimm._src.models.convnext import ConvNeXtLarge
+from kimm._src.models.convnext import ConvNeXtNano
+from kimm._src.models.convnext import ConvNeXtPico
+from kimm._src.models.convnext import ConvNeXtSmall
+from kimm._src.models.convnext import ConvNeXtTiny
+from kimm._src.models.convnext import ConvNeXtXLarge
diff --git a/kimm/models/densenet/__init__.py b/kimm/models/densenet/__init__.py
new file mode 100644
index 0000000..282a027
--- /dev/null
+++ b/kimm/models/densenet/__init__.py
@@ -0,0 +1,10 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.densenet import DenseNet121
+from kimm._src.models.densenet import DenseNet161
+from kimm._src.models.densenet import DenseNet169
+from kimm._src.models.densenet import DenseNet201
diff --git a/kimm/models/efficientnet/__init__.py b/kimm/models/efficientnet/__init__.py
new file mode 100644
index 0000000..5921789
--- /dev/null
+++ b/kimm/models/efficientnet/__init__.py
@@ -0,0 +1,32 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.efficientnet import EfficientNetB0
+from kimm._src.models.efficientnet import EfficientNetB1
+from kimm._src.models.efficientnet import EfficientNetB2
+from kimm._src.models.efficientnet import EfficientNetB3
+from kimm._src.models.efficientnet import EfficientNetB4
+from kimm._src.models.efficientnet import EfficientNetB5
+from kimm._src.models.efficientnet import EfficientNetB6
+from kimm._src.models.efficientnet import EfficientNetB7
+from kimm._src.models.efficientnet import EfficientNetLiteB0
+from kimm._src.models.efficientnet import EfficientNetLiteB1
+from kimm._src.models.efficientnet import EfficientNetLiteB2
+from kimm._src.models.efficientnet import EfficientNetLiteB3
+from kimm._src.models.efficientnet import EfficientNetLiteB4
+from kimm._src.models.efficientnet import EfficientNetV2B0
+from kimm._src.models.efficientnet import EfficientNetV2B1
+from kimm._src.models.efficientnet import EfficientNetV2B2
+from kimm._src.models.efficientnet import EfficientNetV2B3
+from kimm._src.models.efficientnet import EfficientNetV2L
+from kimm._src.models.efficientnet import EfficientNetV2M
+from kimm._src.models.efficientnet import EfficientNetV2S
+from kimm._src.models.efficientnet import EfficientNetV2XL
+from kimm._src.models.efficientnet import TinyNetA
+from kimm._src.models.efficientnet import TinyNetB
+from kimm._src.models.efficientnet import TinyNetC
+from kimm._src.models.efficientnet import TinyNetD
+from kimm._src.models.efficientnet import TinyNetE
diff --git a/kimm/models/ghostnet/__init__.py b/kimm/models/ghostnet/__init__.py
new file mode 100644
index 0000000..d1ae684
--- /dev/null
+++ b/kimm/models/ghostnet/__init__.py
@@ -0,0 +1,12 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.ghostnet import GhostNet050
+from kimm._src.models.ghostnet import GhostNet100
+from kimm._src.models.ghostnet import GhostNet100V2
+from kimm._src.models.ghostnet import GhostNet130
+from kimm._src.models.ghostnet import GhostNet130V2
+from kimm._src.models.ghostnet import GhostNet160V2
diff --git a/kimm/models/hgnet/__init__.py b/kimm/models/hgnet/__init__.py
new file mode 100644
index 0000000..1f3a587
--- /dev/null
+++ b/kimm/models/hgnet/__init__.py
@@ -0,0 +1,16 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.hgnet import HGNetBase
+from kimm._src.models.hgnet import HGNetSmall
+from kimm._src.models.hgnet import HGNetTiny
+from kimm._src.models.hgnet import HGNetV2B0
+from kimm._src.models.hgnet import HGNetV2B1
+from kimm._src.models.hgnet import HGNetV2B2
+from kimm._src.models.hgnet import HGNetV2B3
+from kimm._src.models.hgnet import HGNetV2B4
+from kimm._src.models.hgnet import HGNetV2B5
+from kimm._src.models.hgnet import HGNetV2B6
diff --git a/kimm/models/inception_next/__init__.py b/kimm/models/inception_next/__init__.py
new file mode 100644
index 0000000..9932874
--- /dev/null
+++ b/kimm/models/inception_next/__init__.py
@@ -0,0 +1,9 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.inception_next import InceptionNeXtBase
+from kimm._src.models.inception_next import InceptionNeXtSmall
+from kimm._src.models.inception_next import InceptionNeXtTiny
diff --git a/kimm/models/inception_v3/__init__.py b/kimm/models/inception_v3/__init__.py
new file mode 100644
index 0000000..f8cb17d
--- /dev/null
+++ b/kimm/models/inception_v3/__init__.py
@@ -0,0 +1,7 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.inception_v3 import InceptionV3
diff --git a/kimm/models/mobilenet_v2/__init__.py b/kimm/models/mobilenet_v2/__init__.py
new file mode 100644
index 0000000..4e4c253
--- /dev/null
+++ b/kimm/models/mobilenet_v2/__init__.py
@@ -0,0 +1,11 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.mobilenet_v2 import MobileNetV2W050
+from kimm._src.models.mobilenet_v2 import MobileNetV2W100
+from kimm._src.models.mobilenet_v2 import MobileNetV2W110
+from kimm._src.models.mobilenet_v2 import MobileNetV2W120
+from kimm._src.models.mobilenet_v2 import MobileNetV2W140
diff --git a/kimm/models/mobilenet_v3/__init__.py b/kimm/models/mobilenet_v3/__init__.py
new file mode 100644
index 0000000..ecd8e1d
--- /dev/null
+++ b/kimm/models/mobilenet_v3/__init__.py
@@ -0,0 +1,17 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.mobilenet_v3 import LCNet035
+from kimm._src.models.mobilenet_v3 import LCNet050
+from kimm._src.models.mobilenet_v3 import LCNet075
+from kimm._src.models.mobilenet_v3 import LCNet100
+from kimm._src.models.mobilenet_v3 import LCNet150
+from kimm._src.models.mobilenet_v3 import MobileNetV3W050Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W075Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100Large
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100LargeMinimal
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100Small
+from kimm._src.models.mobilenet_v3 import MobileNetV3W100SmallMinimal
diff --git a/kimm/models/mobileone/__init__.py b/kimm/models/mobileone/__init__.py
new file mode 100644
index 0000000..150978f
--- /dev/null
+++ b/kimm/models/mobileone/__init__.py
@@ -0,0 +1,10 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.mobileone import MobileOneS0
+from kimm._src.models.mobileone import MobileOneS1
+from kimm._src.models.mobileone import MobileOneS2
+from kimm._src.models.mobileone import MobileOneS3
diff --git a/kimm/models/mobilevit/__init__.py b/kimm/models/mobilevit/__init__.py
new file mode 100644
index 0000000..93746b9
--- /dev/null
+++ b/kimm/models/mobilevit/__init__.py
@@ -0,0 +1,16 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.mobilevit import MobileViTS
+from kimm._src.models.mobilevit import MobileViTV2W050
+from kimm._src.models.mobilevit import MobileViTV2W075
+from kimm._src.models.mobilevit import MobileViTV2W100
+from kimm._src.models.mobilevit import MobileViTV2W125
+from kimm._src.models.mobilevit import MobileViTV2W150
+from kimm._src.models.mobilevit import MobileViTV2W175
+from kimm._src.models.mobilevit import MobileViTV2W200
+from kimm._src.models.mobilevit import MobileViTXS
+from kimm._src.models.mobilevit import MobileViTXXS
diff --git a/kimm/models/regnet/__init__.py b/kimm/models/regnet/__init__.py
new file mode 100644
index 0000000..160200d
--- /dev/null
+++ b/kimm/models/regnet/__init__.py
@@ -0,0 +1,30 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.regnet import RegNetX002
+from kimm._src.models.regnet import RegNetX004
+from kimm._src.models.regnet import RegNetX006
+from kimm._src.models.regnet import RegNetX008
+from kimm._src.models.regnet import RegNetX016
+from kimm._src.models.regnet import RegNetX032
+from kimm._src.models.regnet import RegNetX040
+from kimm._src.models.regnet import RegNetX064
+from kimm._src.models.regnet import RegNetX080
+from kimm._src.models.regnet import RegNetX120
+from kimm._src.models.regnet import RegNetX160
+from kimm._src.models.regnet import RegNetX320
+from kimm._src.models.regnet import RegNetY002
+from kimm._src.models.regnet import RegNetY004
+from kimm._src.models.regnet import RegNetY006
+from kimm._src.models.regnet import RegNetY008
+from kimm._src.models.regnet import RegNetY016
+from kimm._src.models.regnet import RegNetY032
+from kimm._src.models.regnet import RegNetY040
+from kimm._src.models.regnet import RegNetY064
+from kimm._src.models.regnet import RegNetY080
+from kimm._src.models.regnet import RegNetY120
+from kimm._src.models.regnet import RegNetY160
+from kimm._src.models.regnet import RegNetY320
diff --git a/kimm/models/repvgg/__init__.py b/kimm/models/repvgg/__init__.py
new file mode 100644
index 0000000..0cb0e54
--- /dev/null
+++ b/kimm/models/repvgg/__init__.py
@@ -0,0 +1,13 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.repvgg import RepVGGA0
+from kimm._src.models.repvgg import RepVGGA1
+from kimm._src.models.repvgg import RepVGGA2
+from kimm._src.models.repvgg import RepVGGB0
+from kimm._src.models.repvgg import RepVGGB1
+from kimm._src.models.repvgg import RepVGGB2
+from kimm._src.models.repvgg import RepVGGB3
diff --git a/kimm/models/resnet/__init__.py b/kimm/models/resnet/__init__.py
new file mode 100644
index 0000000..00eb07b
--- /dev/null
+++ b/kimm/models/resnet/__init__.py
@@ -0,0 +1,11 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.resnet import ResNet18
+from kimm._src.models.resnet import ResNet34
+from kimm._src.models.resnet import ResNet50
+from kimm._src.models.resnet import ResNet101
+from kimm._src.models.resnet import ResNet152
diff --git a/kimm/models/vgg/__init__.py b/kimm/models/vgg/__init__.py
new file mode 100644
index 0000000..c5b7102
--- /dev/null
+++ b/kimm/models/vgg/__init__.py
@@ -0,0 +1,10 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.vgg import VGG11
+from kimm._src.models.vgg import VGG13
+from kimm._src.models.vgg import VGG16
+from kimm._src.models.vgg import VGG19
diff --git a/kimm/models/vision_transformer/__init__.py b/kimm/models/vision_transformer/__init__.py
new file mode 100644
index 0000000..426415e
--- /dev/null
+++ b/kimm/models/vision_transformer/__init__.py
@@ -0,0 +1,14 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.vision_transformer import VisionTransformerBase16
+from kimm._src.models.vision_transformer import VisionTransformerBase32
+from kimm._src.models.vision_transformer import VisionTransformerLarge16
+from kimm._src.models.vision_transformer import VisionTransformerLarge32
+from kimm._src.models.vision_transformer import VisionTransformerSmall16
+from kimm._src.models.vision_transformer import VisionTransformerSmall32
+from kimm._src.models.vision_transformer import VisionTransformerTiny16
+from kimm._src.models.vision_transformer import VisionTransformerTiny32
diff --git a/kimm/models/xception/__init__.py b/kimm/models/xception/__init__.py
new file mode 100644
index 0000000..e632631
--- /dev/null
+++ b/kimm/models/xception/__init__.py
@@ -0,0 +1,7 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.models.xception import Xception
diff --git a/kimm/timm_utils/__init__.py b/kimm/timm_utils/__init__.py
new file mode 100644
index 0000000..0818b61
--- /dev/null
+++ b/kimm/timm_utils/__init__.py
@@ -0,0 +1,10 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.utils.timm_utils import assign_weights
+from kimm._src.utils.timm_utils import is_same_weights
+from kimm._src.utils.timm_utils import separate_keras_weights
+from kimm._src.utils.timm_utils import separate_torch_state_dict
diff --git a/kimm/utils/__init__.py b/kimm/utils/__init__.py
index 1b6830e..52fc9aa 100644
--- a/kimm/utils/__init__.py
+++ b/kimm/utils/__init__.py
@@ -1,6 +1,8 @@
-from kimm.utils.make_divisble import make_divisible
-from kimm.utils.model_registry import add_model_to_registry
-from kimm.utils.model_utils import get_reparameterized_model
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from kimm._src.utils.model_registry import list_models
+from kimm._src.utils.model_utils import get_reparameterized_model
diff --git a/pyproject.toml b/pyproject.toml
index 9552159..f8825de 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,6 +60,7 @@ tests = [
"coverage",
# tool
"pre-commit",
+ "namex",
]
examples = ["opencv-python", "matplotlib"]
@@ -86,7 +87,6 @@ exclude = [
]
[tool.ruff.lint.per-file-ignores]
-"./examples/**/*" = ["E402"]
"**/__init__.py" = ["F401"]
[tool.isort]
@@ -96,7 +96,7 @@ known_first_party = ["kimm"]
line_length = 80
[tool.pytest.ini_options]
-addopts = "--durations 10 --cov --cov-report html --cov-report term:skip-covered --cov-report xml"
+addopts = "-vv --durations 10 --cov --cov-report html --cov-report term:skip-covered --cov-report xml"
testpaths = ["kimm"]
filterwarnings = [
"error",
diff --git a/requirements.txt b/requirements.txt
index fbe34cd..ab9adee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,4 +18,4 @@ torchvision>=0.16.0
jax[cpu]
-keras>=3.0.4
+keras>=3.3.0
diff --git a/shell/api_gen.sh b/shell/api_gen.sh
new file mode 100755
index 0000000..dd5de52
--- /dev/null
+++ b/shell/api_gen.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -Eeuo pipefail
+
+base_dir=$(dirname $(dirname $0))
+
+echo "Generating api directory with public APIs..."
+python3 "${base_dir}"/api_gen.py
+
+echo "Formatting api directory..."
+bash "${base_dir}"/shell/format.sh
+
+echo -e "\nAPI generation finish!"
diff --git a/shell/export.sh b/shell/export_models.sh
similarity index 100%
rename from shell/export.sh
rename to shell/export_models.sh
diff --git a/shell/format.sh b/shell/format.sh
index a42b7c7..dcbe6e3 100755
--- a/shell/format.sh
+++ b/shell/format.sh
@@ -4,4 +4,4 @@ set -Eeuo pipefail
base_dir=$(dirname $(dirname $0))
isort --sp "${base_dir}/pyproject.toml" .
black --config "${base_dir}/pyproject.toml" .
-ruff check --config "${base_dir}/pyproject.toml" .
+ruff check --config "${base_dir}/pyproject.toml" --fix .
diff --git a/tools/convert_convmixer_from_timm.py b/tools/convert_convmixer_from_timm.py
index 371e9a0..ca5a9df 100644
--- a/tools/convert_convmixer_from_timm.py
+++ b/tools/convert_convmixer_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import convmixer
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"convmixer_768_32.in1k",
diff --git a/tools/convert_convnext_from_timm.py b/tools/convert_convnext_from_timm.py
index 6da0e5b..7812a44 100644
--- a/tools/convert_convnext_from_timm.py
+++ b/tools/convert_convnext_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import convnext
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"convnext_atto.d2_in1k",
diff --git a/tools/convert_densenet_from_timm.py b/tools/convert_densenet_from_timm.py
index 2ca9a4c..af64f38 100644
--- a/tools/convert_densenet_from_timm.py
+++ b/tools/convert_densenet_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import densenet
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"densenet121.ra_in1k",
diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py
index 2e3e8cc..f56ef8d 100644
--- a/tools/convert_efficientnet_from_timm.py
+++ b/tools/convert_efficientnet_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import efficientnet
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"tf_efficientnet_b0.ns_jft_in1k",
diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py
index 795fcf7..bea576b 100644
--- a/tools/convert_ghostnet_from_timm.py
+++ b/tools/convert_ghostnet_from_timm.py
@@ -14,10 +14,10 @@
from kimm.models.ghostnet import GhostNet100V2
from kimm.models.ghostnet import GhostNet130V2
from kimm.models.ghostnet import GhostNet160V2
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"ghostnet_100",
diff --git a/tools/convert_hgnet_from_timm.py b/tools/convert_hgnet_from_timm.py
index 8142c78..e8d82d8 100644
--- a/tools/convert_hgnet_from_timm.py
+++ b/tools/convert_hgnet_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import hgnet
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
# HGNet
diff --git a/tools/convert_inception_next_from_timm.py b/tools/convert_inception_next_from_timm.py
index b8a9c28..399dd80 100644
--- a/tools/convert_inception_next_from_timm.py
+++ b/tools/convert_inception_next_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import inception_next
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"inception_next_tiny.sail_in1k",
diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py
index 4851084..ddecbbf 100644
--- a/tools/convert_inception_v3_from_timm.py
+++ b/tools/convert_inception_v3_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import inception_v3
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"inception_v3.gluon_in1k",
diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py
index 28a2208..4b718c6 100644
--- a/tools/convert_mobilenet_v2_from_timm.py
+++ b/tools/convert_mobilenet_v2_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import mobilenet_v2
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"mobilenetv2_050.lamb_in1k",
diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py
index 10b17b6..e0509bd 100644
--- a/tools/convert_mobilenet_v3_from_timm.py
+++ b/tools/convert_mobilenet_v3_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import mobilenet_v3
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"mobilenetv3_small_050.lamb_in1k",
diff --git a/tools/convert_mobileone_from_timm.py b/tools/convert_mobileone_from_timm.py
index 84fe4cc..24d9291 100644
--- a/tools/convert_mobileone_from_timm.py
+++ b/tools/convert_mobileone_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import mobileone
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"mobileone_s0.apple_in1k",
diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py
index a90b9b0..d35369e 100644
--- a/tools/convert_mobilevit_from_timm.py
+++ b/tools/convert_mobilevit_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import mobilevit
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"mobilevit_xxs.cvnets_in1k",
diff --git a/tools/convert_regnet_from_timm.py b/tools/convert_regnet_from_timm.py
index a5539b5..5b4a481 100644
--- a/tools/convert_regnet_from_timm.py
+++ b/tools/convert_regnet_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import regnet
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"regnetx_002.pycls_in1k",
diff --git a/tools/convert_repvgg_from_timm.py b/tools/convert_repvgg_from_timm.py
index a47cbef..8454f8b 100644
--- a/tools/convert_repvgg_from_timm.py
+++ b/tools/convert_repvgg_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import repvgg
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"repvgg_a0.rvgg_in1k",
diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py
index 440a3d5..7e0cb4e 100644
--- a/tools/convert_resnet_from_timm.py
+++ b/tools/convert_resnet_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import resnet
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"resnet18.a1_in1k",
diff --git a/tools/convert_vgg_from_timm.py b/tools/convert_vgg_from_timm.py
index b505396..295eb42 100644
--- a/tools/convert_vgg_from_timm.py
+++ b/tools/convert_vgg_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import vgg
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"vgg11_bn.tv_in1k",
diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py
index a8b78fd..dfb73d8 100644
--- a/tools/convert_vit_from_timm.py
+++ b/tools/convert_vit_from_timm.py
@@ -11,10 +11,10 @@
import torch
from kimm.models import vision_transformer
-from kimm.utils.timm_utils import assign_weights
-from kimm.utils.timm_utils import is_same_weights
-from kimm.utils.timm_utils import separate_keras_weights
-from kimm.utils.timm_utils import separate_torch_state_dict
+from kimm.timm_utils import assign_weights
+from kimm.timm_utils import is_same_weights
+from kimm.timm_utils import separate_keras_weights
+from kimm.timm_utils import separate_torch_state_dict
timm_model_names = [
"vit_tiny_patch16_384",