From 218bcff1ca78ceb1be9e8fe009e7f4e215000259 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Thu, 4 Dec 2025 09:53:19 +0530
Subject: [PATCH 1/4] Adds layer filtering support to the quantization API

---
 keras/src/models/model.py           |  62 ++++++++++++++-
 keras/src/quantizers/gptq_config.py |   8 ++
 keras/src/quantizers/gptq_core.py   | 117 +++++++--------------------
 keras/src/quantizers/gptq_test.py   |  63 +++++++++++++++
 4 files changed, 158 insertions(+), 92 deletions(-)

diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index e8fa6415b103..0caab1e6b51d 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -422,7 +422,25 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
             **kwargs,
         )
 
-    def quantize(self, mode, config=None, **kwargs):
+    def get_quantization_structure(self, mode):
+        """Returns the quantization structure for the model.
+
+        This method is intended to be overridden by model authors to provide
+        topology information required for structure-aware quantization modes
+        like 'gptq'.
+
+        Args:
+            mode: The quantization mode.
+
+        Returns:
+            A dictionary describing the topology, e.g.:
+            `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
+            or `None` if the mode does not require structure or is not
+            supported.
+        """
+        return None
+
+    def quantize(self, mode, config=None, filters=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
@@ -430,9 +448,16 @@ def quantize(self, mode, config=None, **kwargs):
             will be skipped if the layer doesn't implement the function.
 
         Args:
-            mode: The mode of the quantization. Only 'int8' is supported at this
-                time.
+            mode: The mode of the quantization. Only 'int8' and 'gptq' are
+                supported at this time.
+            config: The configuration for the quantization.
+            filters: Optional filters to apply to the quantization.
+                Can be a regex string or a callable.
+                For 'gptq' mode, this filters the `sequential_blocks`.
+                For other modes, this filters the layers to be quantized.
         """
+        import re
+
         from keras.src.dtype_policies import QUANTIZATION_MODES
 
         # Validate inputs.
@@ -449,6 +474,13 @@ def quantize(self, mode, config=None, **kwargs):
                 f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
             )
 
+        if filters is not None:
+            if not isinstance(filters, (str, typing.Callable)):
+                raise ValueError(
+                    "The `filters` argument must be a regex string or a "
+                    f"callable. Received: {type(filters)}"
+                )
+
         if mode == "gptq":
             if not isinstance(config, GPTQConfig):
                 raise ValueError(
@@ -464,6 +496,15 @@ def quantize(self, mode, config=None, **kwargs):
 
         graph_modified = False
         for layer in self._flatten_layers():
+            # Apply filters
+            if filters is not None:
+                if isinstance(filters, str):
+                    if not re.search(filters, layer.name):
+                        continue
+                elif callable(filters):
+                    if not filters(layer):
+                        continue
+
             if len(list(layer._flatten_layers())) == 1:
                 try:
                     layer.quantize(mode, type_check=type_check, config=config)
@@ -474,7 +515,20 @@ def quantize(self, mode, config=None, **kwargs):
                 pass
 
         if mode == "gptq":
-            gptq_quantize(self, config)
+            # Resolve structure
+            structure = config.layer_structure
+            if structure is None:
+                structure = self.get_quantization_structure(mode)
+
+            if structure is None:
+                raise ValueError(
+                    "For 'gptq' mode, a valid quantization structure must be "
+                    "provided either via `config.layer_structure` or by "
+                    "overriding `model.get_quantization_structure(mode)`. "
" + "The structure should be a dictionary with keys " + "'pre_block_layers' and 'sequential_blocks'." + ) + gptq_quantize(self, config, structure) # If any layer was changed, we must rebuild the execution functions. if graph_modified: diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py index eaf9434ee192..d9061c3cb1d9 100644 --- a/keras/src/quantizers/gptq_config.py +++ b/keras/src/quantizers/gptq_config.py @@ -131,6 +131,12 @@ class GPTQConfig: activation_order: (bool, optional) If `True`, reorders weight columns based on activation magnitude, which can improve quantization accuracy. Defaults to `False`. + layer_structure: (dict, optional) A dictionary defining the model's + quantization structure. It should contain: + - "pre_block_layers": list of layers to run before the first block. + - "sequential_blocks": list of blocks to be quantized sequentially. + If not provided, the model must implement + `get_quantization_structure`. """ def __init__( @@ -146,6 +152,7 @@ def __init__( group_size: int = 128, symmetric: bool = False, activation_order: bool = False, + layer_structure: dict = None, ): if weight_bits not in [2, 3, 4, 8]: raise ValueError( @@ -174,6 +181,7 @@ def __init__( self.group_size = group_size self.symmetric = symmetric self.activation_order = activation_order + self.layer_structure = layer_structure def dtype_policy_string(self): """Returns the dtype policy string for this configuration. diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py index b97e929e37d2..1f715371daa0 100644 --- a/keras/src/quantizers/gptq_core.py +++ b/keras/src/quantizers/gptq_core.py @@ -10,7 +10,6 @@ from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap from keras.src.layers import Dense from keras.src.layers import EinsumDense -from keras.src.layers import Embedding from keras.src.quantizers.gptq import GPTQ from keras.src.quantizers.gptq_config import GPTQConfig @@ -193,38 +192,6 @@ def get_dataloader( return samples.astype(np.int32)[:, None, :] -def _get_backbone_layers(model): - """Extract embedding and transformer layers from a KerasHub model.""" - backbone = model.backbone - if not hasattr(backbone, "transformer_layers"): - raise ValueError( - "The model's backbone does not have a 'transformer_layers' " - "attribute. Please ensure you are using a standard KerasHub " - "transformer model." - ) - transformer_blocks = backbone.transformer_layers - - embedding_layer = None - if hasattr(backbone, "token_embedding"): - embedding_layer = backbone.token_embedding - elif hasattr(backbone, "embedding"): - embedding_layer = backbone.embedding - return embedding_layer, transformer_blocks - - -def _get_custom_layers(model): - """Heuristic for extracting embedding + transformer blocks from a custom - model.""" - embedding_layer = None - transformer_blocks = [] - for layer in model.layers: - if isinstance(layer, Embedding) and embedding_layer is None: - embedding_layer = layer - elif getattr(layer, "_layers", None): # container-like block - transformer_blocks.append(layer) - return embedding_layer, transformer_blocks - - def find_layers_in_block(block): """ Finds all Dense and EinsumDense layers in a transformer block. @@ -242,72 +209,40 @@ def find_layers_in_block(block): return found_layers -def apply_gptq_layerwise(model, dataloader, config): +def apply_gptq_layerwise(model, dataloader, config, structure): """Applies GPTQ quantization layer-by-layer to a Keras model. 
 
-    This function is designed to work with common transformer architectures,
-    like those provided by KerasHub. It automatically discovers the model's
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
-
-    The core logic operates as follows:
-    1.  It automatically detects the model's structure, identifying the main
-        embedding layer and a sequence of transformer blocks.
-    2.  It processes the model sequentially, one block at a time. For each
-        block, it uses temporary hooks to capture the input activations of
-        each target layer during a forward pass with the calibration data.
-    3.  These captured activations are used to compute the Hessian matrix for
-        each layer's weights.
-    4.  The GPTQ algorithm is then applied to each layer to find the optimal
-        quantized weights that minimize the error introduced.
-    5.  The output activations from the current block are then used as the
-        input for the next block, ensuring that quantization errors are
-        accounted for throughout the model.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     Args:
-        model: The Keras model instance to be quantized. The function will
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        model: The Keras model instance to be quantized.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
-
-    Raises:
-        ValueError: If the function cannot automatically find an embedding
-            layer or any transformer-like blocks to quantize within the model.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
     """
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
 
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
     if not transformer_blocks:
         raise ValueError(
-            "Could not automatically find any transformer-like blocks to "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )
 
-    # Initial inputs are the outputs of the token embedding layer
-    inputs = [
-        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-        for batch in dataloader
-    ]
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +251,16 @@ def apply_gptq_layerwise(model, dataloader, config):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        sub_layers_map = {
+            name: layer
+            for name, layer in sub_layers_map.items()
+            if getattr(layer, "quantization_mode", None) == "gptq"
+        }
+
         if not sub_layers_map:
             logging.info(
-                f"  No Dense or EinsumDense layers found in block {block_idx}. "
-                "Skipping."
+                f"  No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -357,7 +298,7 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(model, config):
+def gptq_quantize(model, config, structure):
     """
     Top-level function to quantize a Keras model using GPTQ.
     """
@@ -376,7 +317,7 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(model, calibration_dataloader, config)
+    apply_gptq_layerwise(model, calibration_dataloader, config, structure)
 
 
 def get_group_size_for_layer(layer, config):
diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py
index 2c9acb40f28f..9469bd6dbead 100644
--- a/keras/src/quantizers/gptq_test.py
+++ b/keras/src/quantizers/gptq_test.py
@@ -582,6 +582,22 @@ def test_quantize_gptq_combinations(self, dataset, config):
         # Baseline logits
         y_ref = model.predict(x_eval)
 
+        # Define structure for the test model
+        # The model is: Embedding -> SimpleTransformerBlock -> GAP -> Dense
+        # The SimpleTransformerBlock is the sequential block.
+        # The Embedding is the pre-block layer.
+
+        # We need to access the layers from the model.
+        # model.layers[1] is Embedding
+        # model.layers[2] is SimpleTransformerBlock
+        embedding_layer = model.layers[1]
+        transformer_block = model.layers[2]
+
+        layer_structure = {
+            "pre_block_layers": [embedding_layer],
+            "sequential_blocks": [transformer_block],
+        }
+
         base_cfg = dict(
             dataset=calibration_set,
             tokenizer=tokenizer,
@@ -591,6 +607,7 @@ def test_quantize_gptq_combinations(self, dataset, config):
             group_size=32,
             symmetric=False,
             activation_order=False,
+            layer_structure=layer_structure,
         )
 
         gptq_cfg = GPTQConfig(**{**base_cfg, **config})
@@ -625,6 +642,13 @@ def test_quantize_gptq_combinations(self, dataset, config):
             "expected_exception": ValueError,
             "error_msg": "only supported for 'gptq'",
         },
+        {
+            "testcase_name": "gptq_missing_structure",
+            "mode": "gptq",
+            "config": GPTQConfig(dataset=["a"], tokenizer=lambda x: x),
+            "expected_exception": ValueError,
+            "error_msg": "For 'gptq' mode, a valid quantization structure",
+        },
     )
     def test_quantize_scenarios(
         self, mode, config, expected_exception, error_msg
@@ -632,3 +656,42 @@ def test_quantize_scenarios(
         model = _get_simple_model()
         with self.assertRaisesRegex(expected_exception, error_msg):
             model.quantize(mode, config=config)
+
+    def test_gptq_filtering(self):
+        """Tests that filters argument works for GPTQ."""
+        model = _get_sequence_classifier()
+        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
+
+        # Structure
+        embedding_layer = model.layers[1]
+        transformer_block = model.layers[2]
+        layer_structure = {
+            "pre_block_layers": [embedding_layer],
+            "sequential_blocks": [transformer_block],
+        }
+
+        config = GPTQConfig(
+            dataset=[np.zeros((1, SEQ_LEN), dtype="int32")],
+            tokenizer=tokenizer,
+            layer_structure=layer_structure,
+            weight_bits=4,
+            group_size=32,
+        )
+
+        target_layer = transformer_block.ffn.layers[0]
+
+        def filter_fn(layer):
+            return layer.name != target_layer.name
+
+        model.quantize("gptq", config=config, filters=filter_fn)
+
+        # Check that target_layer is NOT quantized
+        self.assertIsNone(getattr(target_layer, "quantization_mode", None))
+        self.assertFalse(hasattr(target_layer, "quantized_kernel"))
+
+        # Check that other dense layers ARE quantized
+        other_dense = transformer_block.ffn.layers[1]
+        self.assertEqual(
+            getattr(other_dense, "quantization_mode", None), "gptq"
+        )
+        self.assertTrue(hasattr(other_dense, "quantized_kernel"))

From 5ed530b7686d5cdc097d7a14178b5c5765cb2576 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Thu, 4 Dec 2025 10:57:13 +0530
Subject: [PATCH 2/4] multi-filter support

---
 keras/src/models/model.py              | 61 +++++++++++-----------
 keras/src/models/model_test.py         | 25 ++++++++++
 keras/src/quantizers/gptq_config.py    | 10 ++--
 keras/src/quantizers/gptq_core.py      | 66 ++++++++++++++++++++----
 keras/src/quantizers/gptq_core_test.py | 23 ++++++---
 keras/src/quantizers/gptq_test.py      | 52 +++++++++++++++-----
 keras/src/quantizers/utils.py          | 23 +++++++++
 keras/src/quantizers/utils_test.py     | 20 ++++++++
 8 files changed, 220 insertions(+), 60 deletions(-)
 create mode 100644 keras/src/quantizers/utils.py
 create mode 100644 keras/src/quantizers/utils_test.py

diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index 0caab1e6b51d..21fd12f4c10e 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -2,6 +2,7 @@
 import json
 import typing
 import warnings
+from collections.abc import Callable
 
 from keras.src import backend
 from keras.src import utils
@@ -10,6 +11,7 @@
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.quantizers.gptq_core import gptq_quantize
+from keras.src.quantizers.utils import should_quantize_layer
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -422,7 +424,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
             **kwargs,
         )
 
-    def get_quantization_structure(self, mode):
+    def get_quantization_layer_structure(self, mode):
         """Returns the quantization structure for the model.
 
         This method is intended to be overridden by model authors to provide
@@ -436,8 +438,13 @@ def get_quantization_structure(self, mode):
             A dictionary describing the topology, e.g.:
             `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
             or `None` if the mode does not require structure or is not
-            supported.
+            supported. `'pre_block_layers'` is a list of layers that
+            the inputs should be passed through, before being passed to
+            the sequential blocks. For example, inputs to an LLM must
+            first be passed through an embedding layer, followed by
+            the transformer.
         """
+        del mode  # Unused.
         return None
 
     def quantize(self, mode, config=None, filters=None, **kwargs):
@@ -448,15 +455,14 @@ def quantize(self, mode, config=None, filters=None, **kwargs):
             will be skipped if the layer doesn't implement the function.
 
         Args:
-            mode: The mode of the quantization. Only 'int8' and 'gptq' are
-                supported at this time.
-            config: The configuration for the quantization.
-            filters: Optional filters to apply to the quantization.
-                Can be a regex string or a callable.
-                For 'gptq' mode, this filters the `sequential_blocks`.
-                For other modes, this filters the layers to be quantized.
+            mode: The mode of the quantization. Supported modes are: 'int4',
+                'int8', 'float8', 'gptq'.
+            config: The configuration object specifying additional
+                quantization options for supported modes.
+            filters: Optional filters to apply to the quantization. Can be a
+                regex string, a list of regex strings, or a callable. Only the
+                layers which match the filter conditions will be quantized.
         """
-        import re
 
         from keras.src.dtype_policies import QUANTIZATION_MODES
 
@@ -475,10 +481,11 @@ def quantize(self, mode, config=None, filters=None, **kwargs):
             )
 
         if filters is not None:
-            if not isinstance(filters, (str, typing.Callable)):
+            if not isinstance(filters, (str, Callable, list, tuple)):
                 raise ValueError(
-                    "The `filters` argument must be a regex string or a "
-                    f"callable. Received: {type(filters)}"
+                    "The `filters` argument must be a regex string, a list of "
+                    "regex strings, or a callable. Received: "
+                    f"{type(filters)}"
                 )
 
         if mode == "gptq":
@@ -497,13 +504,8 @@ def quantize(self, mode, config=None, filters=None, **kwargs):
         graph_modified = False
         for layer in self._flatten_layers():
             # Apply filters
-            if filters is not None:
-                if isinstance(filters, str):
-                    if not re.search(filters, layer.name):
-                        continue
-                elif callable(filters):
-                    if not filters(layer):
-                        continue
+            if not should_quantize_layer(layer, filters):
+                continue
 
             if len(list(layer._flatten_layers())) == 1:
                 try:
@@ -515,20 +517,25 @@ def quantize(self, mode, config=None, filters=None, **kwargs):
                 pass
 
         if mode == "gptq":
-            # Resolve structure
-            structure = config.layer_structure
+            # Resolve model structure.
+            # 1. If quantization_layer_structure is provided inside the config,
+            #    use that.
+            structure = config.quantization_layer_structure
+            # 2. If no layer structure is provided in the config, try to fetch
+            #    it using the `get_quantization_layer_structure` hook.
             if structure is None:
-                structure = self.get_quantization_structure(mode)
+                structure = self.get_quantization_layer_structure(mode)
 
             if structure is None:
                 raise ValueError(
                     "For 'gptq' mode, a valid quantization structure must be "
-                    "provided either via `config.layer_structure` or by "
-                    "overriding `model.get_quantization_structure(mode)`. "
-                    "The structure should be a dictionary with keys "
+                    "provided either via `config.quantization_layer_structure` "
+                    "or by overriding "
+                    "`model.get_quantization_layer_structure(mode)`. The "
+                    "structure should be a dictionary with keys "
                     "'pre_block_layers' and 'sequential_blocks'."
                 )
-            gptq_quantize(self, config, structure)
+            gptq_quantize(self, config, structure, filters=filters)
 
         # If any layer was changed, we must rebuild the execution functions.
         if graph_modified:
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 4b2b5ce00081..4843b31abf10 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -817,6 +817,31 @@ def test_quantize(self, mode):
         if backend.backend() == "torch":
             self.assertLen(list(model.named_parameters()), 16)
 
+    @parameterized.named_parameters(
+        ("regex_string", "dense_1", ["dense_1"]),
+        ("list_of_regex", ["dense_1", "output"], ["dense_1", "output"]),
+        ("callable", lambda l: "dense" in l.name, ["dense_1", "dense_2"]),
+    )
+    def test_quantize_with_filters(self, filters, expected_quantized_layers):
+        mode = "int8"
+        inputs = layers.Input([3])
+        x = layers.Dense(32, name="dense_1")(inputs)
+        x = layers.Dense(32, name="dense_2")(x)
+        outputs = layers.Dense(32, name="output")(x)
+        model = Model(inputs, outputs)
+
+        model.quantize(mode, filters=filters)
+
+        for layer in model._flatten_layers():
+            if layer.name in expected_quantized_layers:
+                self.assertEqual(
+                    layer.dtype_policy.name, f"{mode}_from_float32"
+                )
+            elif isinstance(layer, layers.Dense):
+                self.assertNotEqual(
+                    layer.dtype_policy.name, f"{mode}_from_float32"
+                )
+
     @parameterized.named_parameters(
         ("int8", "int8"),
         ("float8", "float8"),
diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py
index d9061c3cb1d9..edcb465ce4c2 100644
--- a/keras/src/quantizers/gptq_config.py
+++ b/keras/src/quantizers/gptq_config.py
@@ -131,12 +131,12 @@ class GPTQConfig:
         activation_order: (bool, optional) If `True`, reorders weight columns
             based on activation magnitude, which can improve quantization
             accuracy. Defaults to `False`.
-        layer_structure: (dict, optional) A dictionary defining the model's
-            quantization structure. It should contain:
+        quantization_layer_structure: (dict, optional) A dictionary defining the
+            model's quantization structure. It should contain:
             - "pre_block_layers": list of layers to run before the first block.
             - "sequential_blocks": list of blocks to be quantized sequentially.
             If not provided, the model must implement
-            `get_quantization_structure`.
+            `get_quantization_layer_structure`.
""" def __init__( @@ -152,7 +152,7 @@ def __init__( group_size: int = 128, symmetric: bool = False, activation_order: bool = False, - layer_structure: dict = None, + quantization_layer_structure: dict = None, ): if weight_bits not in [2, 3, 4, 8]: raise ValueError( @@ -181,7 +181,7 @@ def __init__( self.group_size = group_size self.symmetric = symmetric self.activation_order = activation_order - self.layer_structure = layer_structure + self.quantization_layer_structure = quantization_layer_structure def dtype_policy_string(self): """Returns the dtype policy string for this configuration. diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py index 1f715371daa0..b2a13a905a11 100644 --- a/keras/src/quantizers/gptq_core.py +++ b/keras/src/quantizers/gptq_core.py @@ -12,6 +12,7 @@ from keras.src.layers import EinsumDense from keras.src.quantizers.gptq import GPTQ from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.utils import should_quantize_layer @contextmanager @@ -209,18 +210,32 @@ def find_layers_in_block(block): return found_layers -def apply_gptq_layerwise(model, dataloader, config, structure): +def apply_gptq_layerwise(model, dataloader, config, structure, filters=None): """Applies GPTQ quantization layer-by-layer to a Keras model. This function uses the provided `structure` to identify pre-quantization layers and sequential blocks. + The core logic operates as follows: + + 1. It processes the model sequentially, one block at a time. For each + block, it uses temporary hooks to capture the input activations of + each target layer during a forward pass with the calibration data. + 2. These captured activations are used to compute the Hessian matrix for + each layer's weights. + 3. The GPTQ algorithm is then applied to each layer to find the optimal + quantized weights that minimize the error introduced. + 4. The output activations from the current block are then used as the + input for the next block, ensuring that quantization errors are + accounted for throughout the model. + Args: model: The Keras model instance to be quantized. dataloader: An iterable providing calibration data. config: A GPTQConfiguration object. structure: A dictionary with keys "pre_block_layers" and "sequential_blocks". + filters: Optional filters to exclude layers from quantization. """ num_samples = config.num_samples @@ -252,11 +267,16 @@ def apply_gptq_layerwise(model, dataloader, config, structure): sub_layers_map = find_layers_in_block(block) # Filter out layers that are not quantized with GPTQ - sub_layers_map = { - name: layer - for name, layer in sub_layers_map.items() - if getattr(layer, "quantization_mode", None) == "gptq" - } + # We also apply the explicit `filters` argument here. + final_sub_layers_map = {} + for name, layer in sub_layers_map.items(): + # 2. Apply explicit filters + if not should_quantize_layer(layer, filters): + continue + + final_sub_layers_map[name] = layer + + sub_layers_map = final_sub_layers_map if not sub_layers_map: logging.info( @@ -298,11 +318,31 @@ def apply_gptq_layerwise(model, dataloader, config, structure): logging.info("Quantization process complete.") -def gptq_quantize(model, config, structure): +def gptq_quantize(model, config, quantization_layer_structure, filters=None): """ - Top-level function to quantize a Keras model using GPTQ. + Quantizes the model using GPTQ. + + Args: + model: The model to be quantized. + config: The GPTQ configuration. 
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-    logging.info("Starting GPTQ quantization process...")
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
@@ -317,7 +357,13 @@ def gptq_quantize(model, config, structure):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(model, calibration_dataloader, config, structure)
+    apply_gptq_layerwise(
+        model,
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
 
 
 def get_group_size_for_layer(layer, config):
diff --git a/keras/src/quantizers/gptq_core_test.py b/keras/src/quantizers/gptq_core_test.py
index 5ac0ecba3787..9bd06528c9ab 100644
--- a/keras/src/quantizers/gptq_core_test.py
+++ b/keras/src/quantizers/gptq_core_test.py
@@ -8,6 +8,7 @@
 from keras.src import testing
 from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.quantizers.gptq_core import get_dataloader
+from keras.src.quantizers.gptq_core import gptq_quantize
 
 VOCAB_SIZE = 100
 
@@ -269,8 +270,17 @@ def test_apply_gptq_on_multi_block_model(self):
             ]
         )
         model.build(input_shape=(None, 10))
+
+        layer_structure = {
+            "pre_block_layers": [model.layers[0]],
+            "sequential_blocks": [model.layers[1], model.layers[2]],
+        }
+
         config = GPTQConfig(
-            dataset=["test data"], tokenizer=MockTokenizer(), group_size=32
+            dataset=["test data"],
+            tokenizer=MockTokenizer(),
+            group_size=32,
+            quantization_layer_structure=layer_structure,
         )
         model.quantize("gptq", config=config)
 
@@ -278,24 +288,24 @@ def test_apply_gptq_on_multi_block_model(self):
         (
             "no_embedding_layer",
             models.Sequential([layers.Dense(10)]),
-            "Could not automatically find an embedding layer",
+            "For 'gptq' mode, a valid quantization structure must be provided",
         ),
         (
             "no_transformer_blocks",
             models.Sequential(
                 [layers.Embedding(VOCAB_SIZE, 10), layers.Dense(10)]
             ),
-            "Could not automatically find any transformer-like blocks",
+            "For 'gptq' mode, a valid quantization structure must be provided",
         ),
         (
             "backbone_no_layers",
             _get_model_with_backbone(has_transformer_layers=False),
-            "Could not automatically find any transformer-like blocks",
+            "For 'gptq' mode, a valid quantization structure must be provided",
        ),
        (
             "backbone_no_embedding",
             _get_model_with_backbone(embedding_name="wrong_name"),
-            "Could not automatically find an embedding layer in the model",
+            "For 'gptq' mode, a valid quantization structure must be provided",
         ),
     )
     def test_apply_gptq_with_unsupported_architectures(
@@ -308,4 +318,5 @@ def test_apply_gptq_with_unsupported_architectures(
         config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer())
         with self.assertRaisesRegex(ValueError, error_message):
-            model.quantize("gptq", config=config)
+            # We pass None as structure to trigger the error
+            gptq_quantize(model, config, quantization_layer_structure=None)
diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py
index 9469bd6dbead..d6fe0048ac3f 100644
--- a/keras/src/quantizers/gptq_test.py
+++ b/keras/src/quantizers/gptq_test.py
@@ -582,14 +582,6 @@ def test_quantize_gptq_combinations(self, dataset, config):
         # Baseline logits
         y_ref = model.predict(x_eval)
 
-        # Define structure for the test model
-        # The model is: Embedding -> SimpleTransformerBlock -> GAP -> Dense
-        # The SimpleTransformerBlock is the sequential block.
-        # The Embedding is the pre-block layer.
-
-        # We need to access the layers from the model.
-        # model.layers[1] is Embedding
-        # model.layers[2] is SimpleTransformerBlock
         embedding_layer = model.layers[1]
         transformer_block = model.layers[2]
 
@@ -607,7 +599,7 @@ def test_quantize_gptq_combinations(self, dataset, config):
             group_size=32,
             symmetric=False,
             activation_order=False,
-            layer_structure=layer_structure,
+            quantization_layer_structure=layer_structure,
         )
 
         gptq_cfg = GPTQConfig(**{**base_cfg, **config})
@@ -665,7 +657,7 @@ def test_gptq_filtering(self):
         config = GPTQConfig(
             dataset=[np.zeros((1, SEQ_LEN), dtype="int32")],
             tokenizer=tokenizer,
-            layer_structure=layer_structure,
+            quantization_layer_structure=layer_structure,
             weight_bits=4,
             group_size=32,
         )
@@ -685,13 +677,49 @@ def filter_fn(layer):
 
         model.quantize("gptq", config=config, filters=filter_fn)
 
-        # Check that target_layer is NOT quantized
+        # Check that target_layer is NOT quantized.
         self.assertIsNone(getattr(target_layer, "quantization_mode", None))
         self.assertFalse(hasattr(target_layer, "quantized_kernel"))
 
-        # Check that other dense layers ARE quantized
+        # Check that other dense layers ARE quantized.
         other_dense = transformer_block.ffn.layers[1]
         self.assertEqual(
             getattr(other_dense, "quantization_mode", None), "gptq"
         )
         self.assertTrue(hasattr(other_dense, "quantized_kernel"))
+
+    def test_gptq_multi_filtering(self):
+        """Tests that list of regex filters works for GPTQ."""
+        model = _get_sequence_classifier()
+        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
+
+        embedding_layer = model.layers[1]
+        transformer_block = model.layers[2]
+        layer_structure = {
+            "pre_block_layers": [embedding_layer],
+            "sequential_blocks": [transformer_block],
+        }
+
+        config = GPTQConfig(
+            dataset=[np.zeros((1, SEQ_LEN), dtype="int32")],
+            tokenizer=tokenizer,
+            quantization_layer_structure=layer_structure,
+            weight_bits=4,
+            group_size=32,
+        )
+
+        layer0 = transformer_block.ffn.layers[0]
+        layer1 = transformer_block.ffn.layers[1]
+
+        # We want to quantize only layer0.
+        filters = [f"^{layer0.name}$"]
+
+        model.quantize("gptq", config=config, filters=filters)
+
+        # Check that layer0 is quantized.
+        self.assertEqual(getattr(layer0, "quantization_mode", None), "gptq")
+        self.assertTrue(hasattr(layer0, "quantized_kernel"))
+
+        # Check that layer1 is not quantized.
+        self.assertIsNone(getattr(layer1, "quantization_mode", None))
+        self.assertFalse(hasattr(layer1, "quantized_kernel"))
diff --git a/keras/src/quantizers/utils.py b/keras/src/quantizers/utils.py
new file mode 100644
index 000000000000..196ed6642909
--- /dev/null
+++ b/keras/src/quantizers/utils.py
@@ -0,0 +1,23 @@
+import re
+
+
+def should_quantize_layer(layer, filters):
+    """Determines if a layer should be quantized based on filters.
+
+    Args:
+        layer: The layer to check.
+        filters: A regex string, a list of regex strings, or a callable.
+            If None, returns True.
+
+    Returns:
+        True if the layer should be quantized, False otherwise.
+ """ + if filters is None: + return True + if isinstance(filters, str): + return bool(re.search(filters, layer.name)) + if isinstance(filters, (list, tuple)): + return any(re.search(pat, layer.name) for pat in filters) + if callable(filters): + return filters(layer) + return True diff --git a/keras/src/quantizers/utils_test.py b/keras/src/quantizers/utils_test.py new file mode 100644 index 000000000000..0ded897b9c52 --- /dev/null +++ b/keras/src/quantizers/utils_test.py @@ -0,0 +1,20 @@ +from absl.testing import parameterized + +from keras.src import layers +from keras.src import testing +from keras.src.quantizers import utils + + +class UtilsTest(testing.TestCase): + @parameterized.named_parameters( + ("none_filter", None, "dense", True), + ("regex_match", "dense", "dense_1", True), + ("regex_no_match", "conv", "dense_1", False), + ("list_match", ["dense", "conv"], "dense_1", True), + ("list_no_match", ["conv", "pool"], "dense_1", False), + ("callable_match", lambda l: "dense" in l.name, "dense_1", True), + ("callable_no_match", lambda l: "conv" in l.name, "dense_1", False), + ) + def test_should_quantize_layer(self, filters, layer_name, expected): + layer = layers.Layer(name=layer_name) + self.assertEqual(utils.should_quantize_layer(layer, filters), expected) From 2ec0e20205341c3f4bfc3b4fa28c90bb96c166b9 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Thu, 4 Dec 2025 17:08:55 +0530 Subject: [PATCH 3/4] removes unused params comment --- keras/src/models/model.py | 2 +- keras/src/quantizers/gptq_core.py | 11 ++++++----- keras/src/quantizers/gptq_core_test.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 21fd12f4c10e..a1ebcd385337 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -535,7 +535,7 @@ def quantize(self, mode, config=None, filters=None, **kwargs): "structure should be a dictionary with keys " "'pre_block_layers' and 'sequential_blocks'." ) - gptq_quantize(self, config, structure, filters=filters) + gptq_quantize(config, structure, filters=filters) # If any layer was changed, we must rebuild the execution functions. if graph_modified: diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py index b2a13a905a11..d6e86da74775 100644 --- a/keras/src/quantizers/gptq_core.py +++ b/keras/src/quantizers/gptq_core.py @@ -210,7 +210,7 @@ def find_layers_in_block(block): return found_layers -def apply_gptq_layerwise(model, dataloader, config, structure, filters=None): +def apply_gptq_layerwise(dataloader, config, structure, filters=None): """Applies GPTQ quantization layer-by-layer to a Keras model. This function uses the provided `structure` to identify pre-quantization @@ -230,12 +230,15 @@ def apply_gptq_layerwise(model, dataloader, config, structure, filters=None): accounted for throughout the model. Args: - model: The Keras model instance to be quantized. dataloader: An iterable providing calibration data. config: A GPTQConfiguration object. structure: A dictionary with keys "pre_block_layers" and "sequential_blocks". filters: Optional filters to exclude layers from quantization. + + Raises: + ValueError: If the function cannot automatically find an embedding + layer or any transformer-like blocks to quantize within the model. 
""" num_samples = config.num_samples @@ -318,12 +321,11 @@ def apply_gptq_layerwise(model, dataloader, config, structure, filters=None): logging.info("Quantization process complete.") -def gptq_quantize(model, config, quantization_layer_structure, filters=None): +def gptq_quantize(config, quantization_layer_structure, filters=None): """ Quantizes the model using GPTQ. Args: - model: The model to be quantized. config: The GPTQ configuration. quantization_layer_structure: A dictionary describing the model's layer structure for quantization. @@ -358,7 +360,6 @@ def gptq_quantize(model, config, quantization_layer_structure, filters=None): calibration_dataloader = dataloader[: config.num_samples] apply_gptq_layerwise( - model, calibration_dataloader, config, quantization_layer_structure, diff --git a/keras/src/quantizers/gptq_core_test.py b/keras/src/quantizers/gptq_core_test.py index 9bd06528c9ab..74315a4f258a 100644 --- a/keras/src/quantizers/gptq_core_test.py +++ b/keras/src/quantizers/gptq_core_test.py @@ -319,4 +319,4 @@ def test_apply_gptq_with_unsupported_architectures( config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer()) with self.assertRaisesRegex(ValueError, error_message): # We pass None as structure to trigger the error - gptq_quantize(model, config, quantization_layer_structure=None) + gptq_quantize(config, quantization_layer_structure=None) From ec7ac8ff6f984e80ebd64aa8cf4aa8b8325b0f5f Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Thu, 4 Dec 2025 17:20:58 +0530 Subject: [PATCH 4/4] fix comments --- keras/src/quantizers/gptq_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py index d6e86da74775..7f6ccd2eb1c5 100644 --- a/keras/src/quantizers/gptq_core.py +++ b/keras/src/quantizers/gptq_core.py @@ -270,10 +270,8 @@ def apply_gptq_layerwise(dataloader, config, structure, filters=None): sub_layers_map = find_layers_in_block(block) # Filter out layers that are not quantized with GPTQ - # We also apply the explicit `filters` argument here. final_sub_layers_map = {} for name, layer in sub_layers_map.items(): - # 2. Apply explicit filters if not should_quantize_layer(layer, filters): continue