69 changes: 65 additions & 4 deletions keras/src/models/model.py
@@ -2,6 +2,7 @@
import json
import typing
import warnings
from collections.abc import Callable

from keras.src import backend
from keras.src import utils
@@ -10,6 +11,7 @@
from keras.src.models.variable_mapping import map_saveable_variables
from keras.src.quantizers.gptq_config import GPTQConfig
from keras.src.quantizers.gptq_core import gptq_quantize
from keras.src.quantizers.utils import should_quantize_layer
from keras.src.saving import saving_api
from keras.src.trainers import trainer as base_trainer
from keras.src.utils import summary_utils
@@ -422,17 +424,46 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
**kwargs,
)

def quantize(self, mode, config=None, **kwargs):
def get_quantization_layer_structure(self, mode):
"""Returns the quantization structure for the model.

This method is intended to be overridden by model authors to provide
topology information required for structure-aware quantization modes
like 'gptq'.

Args:
mode: The quantization mode.

Returns:
A dictionary describing the topology, e.g.:
`{'pre_block_layers': [list], 'sequential_blocks': [list]}`
or `None` if the mode does not require structure or is not
supported. `'pre_block_layers'` is a list of layers that the
inputs are passed through before reaching the sequential
blocks. For example, inputs to an LLM must first pass through
an embedding layer before reaching the transformer blocks.
"""
del mode # Unused.
return None
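For model authors, here is a minimal sketch of how this hook might be overridden on a custom transformer-style model. The class name and the `token_embedding` / `transformer_blocks` attributes are hypothetical and not part of this PR:

```python
import keras


class MyTransformerLM(keras.Model):
    """Hypothetical model exposing its topology for 'gptq' mode."""

    def get_quantization_layer_structure(self, mode):
        if mode != "gptq":
            # Other modes do not need topology information.
            return None
        return {
            # Layers calibration inputs pass through first (e.g. embedding).
            "pre_block_layers": [self.token_embedding],
            # Blocks that are quantized one at a time, in order.
            "sequential_blocks": list(self.transformer_blocks),
        }
```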

def quantize(self, mode, config=None, filters=None, **kwargs):
"""Quantize the weights of the model.

Note that the model must be built before calling this method.
`quantize` recursively calls `quantize(mode)` on all layers; layers
that don't implement the function are skipped.

Args:
mode: The mode of the quantization. Only 'int8' is supported at this
time.
mode: The mode of the quantization. Supported modes are: 'int4',
'int8', 'float8', 'gptq'.
config: The configuration object specifying additional
quantization options for supported modes.
filters: Optional filters restricting which layers are quantized.
Can be a regex string, a list of regex strings, or a callable
that takes a layer and returns a boolean. Only layers that
match the filter conditions are quantized.
"""

from keras.src.dtype_policies import QUANTIZATION_MODES

# Validate inputs.
Expand All @@ -449,6 +480,14 @@ def quantize(self, mode, config=None, **kwargs):
f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
)

if filters is not None:
if not isinstance(filters, (str, Callable, list, tuple)):
raise ValueError(
"The `filters` argument must be a regex string, a list of "
"regex strings, or a callable. Received: "
f"{type(filters)}"
)

if mode == "gptq":
if not isinstance(config, GPTQConfig):
raise ValueError(
Expand All @@ -464,6 +503,10 @@ def quantize(self, mode, config=None, **kwargs):

graph_modified = False
for layer in self._flatten_layers():
# Apply filters
if not should_quantize_layer(layer, filters):
continue

if len(list(layer._flatten_layers())) == 1:
try:
layer.quantize(mode, type_check=type_check, config=config)
Expand All @@ -474,7 +517,25 @@ def quantize(self, mode, config=None, **kwargs):
pass

if mode == "gptq":
gptq_quantize(self, config)
# Resolve model structure.
# 1. If quantization_layer_structure is provided inside the config,
# use that.
structure = config.quantization_layer_structure
# 2. If no layer structure is provided in the config, try to fetch
# it using the `get_quantization_layer_structure` hook.
if structure is None:
structure = self.get_quantization_layer_structure(mode)

if structure is None:
raise ValueError(
"For 'gptq' mode, a valid quantization structure must be "
"provided either via `config.quantization_layer_structure` "
"or by overriding "
"`model.get_quantization_layer_structure(mode)`. The "
"structure should be a dictionary with keys "
"'pre_block_layers' and 'sequential_blocks'."
)
gptq_quantize(config, structure, filters=filters)

# If any layer was changed, we must rebuild the execution functions.
if graph_modified:
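A short usage sketch of the new `filters` argument on a built functional model, mirroring the test added below; the layer names are illustrative:

```python
import keras
from keras import layers

inputs = layers.Input([3])
x = layers.Dense(32, name="dense_1")(inputs)
x = layers.Dense(32, name="dense_2")(x)
outputs = layers.Dense(32, name="output")(x)
model = keras.Model(inputs, outputs)

# Only layers whose names match one of the regex patterns are quantized;
# here "dense_2" keeps its original float policy.
model.quantize("int8", filters=["dense_1", "output"])

# A callable filter works too, e.g.:
#   model.quantize("int8", filters=lambda layer: "dense" in layer.name)
```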
25 changes: 25 additions & 0 deletions keras/src/models/model_test.py
@@ -817,6 +817,31 @@ def test_quantize(self, mode):
if backend.backend() == "torch":
self.assertLen(list(model.named_parameters()), 16)

@parameterized.named_parameters(
("regex_string", "dense_1", ["dense_1"]),
("list_of_regex", ["dense_1", "output"], ["dense_1", "output"]),
("callable", lambda l: "dense" in l.name, ["dense_1", "dense_2"]),
)
def test_quantize_with_filters(self, filters, expected_quantized_layers):
mode = "int8"
inputs = layers.Input([3])
x = layers.Dense(32, name="dense_1")(inputs)
x = layers.Dense(32, name="dense_2")(x)
outputs = layers.Dense(32, name="output")(x)
model = Model(inputs, outputs)

model.quantize(mode, filters=filters)

for layer in model._flatten_layers():
if layer.name in expected_quantized_layers:
self.assertEqual(
layer.dtype_policy.name, f"{mode}_from_float32"
)
elif isinstance(layer, layers.Dense):
self.assertNotEqual(
layer.dtype_policy.name, f"{mode}_from_float32"
)

@parameterized.named_parameters(
("int8", "int8"),
("float8", "float8"),
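The test above pins down the expected filter semantics. The `should_quantize_layer` utility imported from `keras.src.quantizers.utils` is not shown in this diff; behavior consistent with the test could look roughly like the sketch below (an illustration, not the PR's actual implementation):

```python
import re


def should_quantize_layer(layer, filters):
    """Illustrative sketch: True if `layer` passes the given filters."""
    if filters is None:
        # No filters supplied: quantize every layer.
        return True
    if callable(filters):
        return bool(filters(layer))
    patterns = [filters] if isinstance(filters, str) else filters
    # Regex form(s): keep the layer if any pattern matches its name.
    return any(re.search(pattern, layer.name) for pattern in patterns)
```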
8 changes: 8 additions & 0 deletions keras/src/quantizers/gptq_config.py
@@ -131,6 +131,12 @@ class GPTQConfig:
activation_order: (bool, optional) If `True`, reorders weight columns
based on activation magnitude, which can improve quantization
accuracy. Defaults to `False`.
quantization_layer_structure: (dict, optional) A dictionary defining the
model's quantization structure. It should contain:
- "pre_block_layers": list of layers to run before the first block.
- "sequential_blocks": list of blocks to be quantized sequentially.
If not provided, the model must override
`get_quantization_layer_structure()`.
"""

def __init__(
Expand All @@ -146,6 +152,7 @@ def __init__(
group_size: int = 128,
symmetric: bool = False,
activation_order: bool = False,
quantization_layer_structure: dict = None,
):
if weight_bits not in [2, 3, 4, 8]:
raise ValueError(
@@ -174,6 +181,7 @@ def __init__(
self.group_size = group_size
self.symmetric = symmetric
self.activation_order = activation_order
self.quantization_layer_structure = quantization_layer_structure

def dtype_policy_string(self):
"""Returns the dtype policy string for this configuration.
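As a usage sketch for the new field: only `quantization_layer_structure` is introduced in this PR; the calibration `dataset`, `tokenizer`, and the model attributes referenced below are placeholders:

```python
from keras.src.quantizers.gptq_config import GPTQConfig

# Topology of the (hypothetical) model: embedding first, then blocks.
structure = {
    "pre_block_layers": [model.token_embedding],
    "sequential_blocks": list(model.transformer_blocks),
}

config = GPTQConfig(
    dataset=calibration_texts,  # placeholder calibration data
    tokenizer=tokenizer,        # placeholder tokenizer
    weight_bits=4,
    quantization_layer_structure=structure,
)

# Alternatively, omit the field and override
# `model.get_quantization_layer_structure("gptq")` on the model instead.
model.quantize("gptq", config=config)
```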
142 changes: 64 additions & 78 deletions keras/src/quantizers/gptq_core.py
@@ -10,9 +10,9 @@
from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
from keras.src.layers import Dense
from keras.src.layers import EinsumDense
from keras.src.layers import Embedding
from keras.src.quantizers.gptq import GPTQ
from keras.src.quantizers.gptq_config import GPTQConfig
from keras.src.quantizers.utils import should_quantize_layer


@contextmanager
@@ -193,38 +193,6 @@ def get_dataloader(
return samples.astype(np.int32)[:, None, :]


def _get_backbone_layers(model):
"""Extract embedding and transformer layers from a KerasHub model."""
backbone = model.backbone
if not hasattr(backbone, "transformer_layers"):
raise ValueError(
"The model's backbone does not have a 'transformer_layers' "
"attribute. Please ensure you are using a standard KerasHub "
"transformer model."
)
transformer_blocks = backbone.transformer_layers

embedding_layer = None
if hasattr(backbone, "token_embedding"):
embedding_layer = backbone.token_embedding
elif hasattr(backbone, "embedding"):
embedding_layer = backbone.embedding
return embedding_layer, transformer_blocks


def _get_custom_layers(model):
"""Heuristic for extracting embedding + transformer blocks from a custom
model."""
embedding_layer = None
transformer_blocks = []
for layer in model.layers:
if isinstance(layer, Embedding) and embedding_layer is None:
embedding_layer = layer
elif getattr(layer, "_layers", None): # container-like block
transformer_blocks.append(layer)
return embedding_layer, transformer_blocks


def find_layers_in_block(block):
"""
Finds all Dense and EinsumDense layers in a transformer block.
@@ -242,39 +210,31 @@ def find_layers_in_block(block):
return found_layers


def apply_gptq_layerwise(model, dataloader, config):
def apply_gptq_layerwise(dataloader, config, structure, filters=None):
"""Applies GPTQ quantization layer-by-layer to a Keras model.

This function is designed to work with common transformer architectures,
like those provided by KerasHub. It automatically discovers the model's
structure by first looking for the standard format: a `model.backbone`
attribute that contains a `transformer_layers` list.

If a standard backbone is not found, it falls back to a heuristic for
custom models, where it assumes the first `keras.layers.Embedding` layer
is the input embedding and any subsequent container layers are the
transformer blocks to be quantized.
This function uses the provided `structure` to identify pre-quantization
layers and sequential blocks.

The core logic operates as follows:
1. It automatically detects the model's structure, identifying the main
embedding layer and a sequence of transformer blocks.
2. It processes the model sequentially, one block at a time. For each

1. It processes the model sequentially, one block at a time. For each
block, it uses temporary hooks to capture the input activations of
each target layer during a forward pass with the calibration data.
3. These captured activations are used to compute the Hessian matrix for
2. These captured activations are used to compute the Hessian matrix for
each layer's weights.
4. The GPTQ algorithm is then applied to each layer to find the optimal
3. The GPTQ algorithm is then applied to each layer to find the optimal
quantized weights that minimize the error introduced.
5. The output activations from the current block are then used as the
4. The output activations from the current block are then used as the
input for the next block, ensuring that quantization errors are
accounted for throughout the model.

Args:
model: The Keras model instance to be quantized. The function will
attempt to automatically discover its structure.
dataloader: An iterable providing calibration data. Each item should
be a batch of token IDs suitable for the model's embedding layer.
dataloader: An iterable providing calibration data.
config: A `GPTQConfig` object.
structure: A dictionary with keys "pre_block_layers" and
"sequential_blocks".
filters: Optional filters; only layers that match them are quantized.

Raises:
ValueError: If the function cannot automatically find an embedding
Expand All @@ -284,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
num_samples = config.num_samples

logging.info("Starting model quantization...")
embedding_layer = None
transformer_blocks = []
if hasattr(model, "backbone"):
logging.info("Detected KerasHub model structure.")
embedding_layer, transformer_blocks = _get_backbone_layers(model)
else:
logging.info("Detected custom model structure.")
embedding_layer, transformer_blocks = _get_custom_layers(model)

if embedding_layer is None:
raise ValueError(
"Could not automatically find an embedding layer in the model."
)
pre_layers = structure.get("pre_block_layers", [])
transformer_blocks = structure.get("sequential_blocks", [])

if not transformer_blocks:
raise ValueError(
"Could not automatically find any transformer-like blocks to "
"quantize."
"No sequential blocks found in the provided structure to quantize."
)

# Initial inputs are the outputs of the token embedding layer
inputs = [
embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
for batch in dataloader
]
# Initial inputs are the outputs of the pre-block layers
inputs = []
for batch in dataloader:
batch = ops.convert_to_tensor(batch, dtype="int32")
for layer in pre_layers:
batch = layer(batch)
inputs.append(batch)

num_samples = min(num_samples, len(inputs))

progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +269,19 @@
logging.info(f"Quantizing Block {block_idx}")
sub_layers_map = find_layers_in_block(block)

# Filter out layers that should not be quantized with GPTQ
final_sub_layers_map = {}
for name, layer in sub_layers_map.items():
if not should_quantize_layer(layer, filters):
continue

final_sub_layers_map[name] = layer

sub_layers_map = final_sub_layers_map

if not sub_layers_map:
logging.info(
f" No Dense or EinsumDense layers found in block {block_idx}. "
"Skipping."
f" No quantizable layers found in block {block_idx}. Skipping."
)
else:
logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -357,11 +319,30 @@
logging.info("Quantization process complete.")


def gptq_quantize(model, config):
def gptq_quantize(config, quantization_layer_structure, filters=None):
"""
Top-level function to quantize a Keras model using GPTQ.
Quantizes the model using GPTQ.

Args:
config: The GPTQ configuration.
quantization_layer_structure: A dictionary describing the model's layer
structure for quantization.
filters: Optional filters; only layers that match them are quantized.
"""
logging.info("Starting GPTQ quantization process...")
if config.dataset is None or config.tokenizer is None:
raise ValueError(
"GPTQ quantization requires a dataset and a tokenizer. "
"Please provide them in the `GPTQConfig`."
)

if quantization_layer_structure is None:
raise ValueError(
"For 'gptq' mode, a valid quantization structure must be provided "
"either via `config.quantization_layer_structure` or by overriding "
"`model.get_quantization_layer_structure(mode)`. The structure "
"should be a dictionary with keys 'pre_block_layers' and "
"'sequential_blocks'."
)

# Load all data needed from the generator/source in a single call.
total_samples_to_request = config.num_samples
@@ -376,7 +357,12 @@
# is now a NumPy array, which can be sliced and reused.
calibration_dataloader = dataloader[: config.num_samples]

apply_gptq_layerwise(model, calibration_dataloader, config)
apply_gptq_layerwise(
calibration_dataloader,
config,
quantization_layer_structure,
filters=filters,
)


def get_group_size_for_layer(layer, config):