Commit 9d1d650

Introduces layer filtering for quantization and fixes GPTQ dependency inversion (#21894)

* Adds layer filtering support to the quantization API
* multi-filter support
* removes unused params comment
* fix comments
1 parent: 5e40ca0

File tree: 8 files changed, +313 −88 lines changed

keras/src/models/model.py

Lines changed: 65 additions & 4 deletions
@@ -2,6 +2,7 @@
 import json
 import typing
 import warnings
+from collections.abc import Callable
 
 from keras.src import backend
 from keras.src import utils
@@ -10,6 +11,7 @@
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.quantizers.gptq_core import gptq_quantize
+from keras.src.quantizers.utils import should_quantize_layer
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -422,17 +424,46 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
             **kwargs,
         )
 
-    def quantize(self, mode, config=None, **kwargs):
+    def get_quantization_layer_structure(self, mode):
+        """Returns the quantization structure for the model.
+
+        This method is intended to be overridden by model authors to provide
+        topology information required for structure-aware quantization modes
+        like 'gptq'.
+
+        Args:
+            mode: The quantization mode.
+
+        Returns:
+            A dictionary describing the topology, e.g.:
+            `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
+            or `None` if the mode does not require structure or is not
+            supported. `'pre_block_layers'` is a list of layers that
+            the inputs should be passed through, before being passed to
+            the sequential blocks. For example, inputs to an LLM must
+            first be passed through an embedding layer, followed by
+            the transformer.
+        """
+        del mode  # Unused.
+        return None
+
+    def quantize(self, mode, config=None, filters=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
         `quantize` will recursively call `quantize(mode)` in all layers and
         will be skipped if the layer doesn't implement the function.
 
         Args:
-            mode: The mode of the quantization. Only 'int8' is supported at this
-                time.
+            mode: The mode of the quantization. Supported modes are: 'int4',
+                'int8', 'float8', 'gptq'.
+            config: The configuration object specifying additional
+                quantization options for supported modes.
+            filters: Optional filters to apply to the quantization. Can be a
+                regex string, a list of regex strings, or a callable. Only the
+                layers which match the filter conditions will be quantized.
         """
+
         from keras.src.dtype_policies import QUANTIZATION_MODES
 
         # Validate inputs.
@@ -449,6 +480,14 @@ def quantize(self, mode, config=None, **kwargs):
                 f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
             )
 
+        if filters is not None:
+            if not isinstance(filters, (str, Callable, list, tuple)):
+                raise ValueError(
+                    "The `filters` argument must be a regex string, a list of "
+                    "regex strings, or a callable. Received: "
+                    f"{type(filters)}"
+                )
+
        if mode == "gptq":
            if not isinstance(config, GPTQConfig):
                raise ValueError(
@@ -464,6 +503,10 @@ def quantize(self, mode, config=None, **kwargs):
 
        graph_modified = False
        for layer in self._flatten_layers():
+            # Apply filters
+            if not should_quantize_layer(layer, filters):
+                continue
+
            if len(list(layer._flatten_layers())) == 1:
                try:
                    layer.quantize(mode, type_check=type_check, config=config)
@@ -474,7 +517,25 @@ def quantize(self, mode, config=None, **kwargs):
                pass
 
        if mode == "gptq":
-            gptq_quantize(self, config)
+            # Resolve model structure.
+            # 1. If quantization_layer_structure is provided inside the config,
+            # use that.
+            structure = config.quantization_layer_structure
+            # 2. If no layer structure is provided in the config, try to fetch
+            # it using the `get_quantization_layer_structure` hook.
+            if structure is None:
+                structure = self.get_quantization_layer_structure(mode)
+
+            if structure is None:
+                raise ValueError(
+                    "For 'gptq' mode, a valid quantization structure must be "
+                    "provided either via `config.quantization_layer_structure` "
+                    "or by overriding "
+                    "`model.get_quantization_layer_structure(mode)`. The "
+                    "structure should be a dictionary with keys "
+                    "'pre_block_layers' and 'sequential_blocks'."
+                )
+            gptq_quantize(config, structure, filters=filters)
 
        # If any layer was changed, we must rebuild the execution functions.
        if graph_modified:
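Usage sketch (not part of this commit's diff): the new `filters` argument accepts a regex string, a list of regex strings, or a callable, and only matching layers are quantized. The model and layer names below are hypothetical, and the regexes are assumed to be matched against layer names, as in the unit test that follows.

import keras
from keras import layers

# Hypothetical model used only to illustrate the `filters` argument.
inputs = layers.Input([16])
x = layers.Dense(64, name="dense_1")(inputs)
x = layers.Dense(64, name="dense_2")(x)
outputs = layers.Dense(8, name="output_head")(x)
model = keras.Model(inputs, outputs)

# Regex string: quantize only layers whose names match "dense_.*".
model.quantize("int8", filters="dense_.*")

# Equivalent alternatives, shown commented out since a model is quantized once:
# model.quantize("int8", filters=["dense_1", "dense_2"])               # list of regexes
# model.quantize("int8", filters=lambda layer: "dense" in layer.name)  # callable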

keras/src/models/model_test.py

Lines changed: 25 additions & 0 deletions
@@ -817,6 +817,31 @@ def test_quantize(self, mode):
         if backend.backend() == "torch":
             self.assertLen(list(model.named_parameters()), 16)
 
+    @parameterized.named_parameters(
+        ("regex_string", "dense_1", ["dense_1"]),
+        ("list_of_regex", ["dense_1", "output"], ["dense_1", "output"]),
+        ("callable", lambda l: "dense" in l.name, ["dense_1", "dense_2"]),
+    )
+    def test_quantize_with_filters(self, filters, expected_quantized_layers):
+        mode = "int8"
+        inputs = layers.Input([3])
+        x = layers.Dense(32, name="dense_1")(inputs)
+        x = layers.Dense(32, name="dense_2")(x)
+        outputs = layers.Dense(32, name="output")(x)
+        model = Model(inputs, outputs)
+
+        model.quantize(mode, filters=filters)
+
+        for layer in model._flatten_layers():
+            if layer.name in expected_quantized_layers:
+                self.assertEqual(
+                    layer.dtype_policy.name, f"{mode}_from_float32"
+                )
+            elif isinstance(layer, layers.Dense):
+                self.assertNotEqual(
+                    layer.dtype_policy.name, f"{mode}_from_float32"
+                )
+
     @parameterized.named_parameters(
         ("int8", "int8"),
         ("float8", "float8"),

keras/src/quantizers/gptq_config.py

Lines changed: 8 additions & 0 deletions
@@ -131,6 +131,12 @@ class GPTQConfig:
         activation_order: (bool, optional) If `True`, reorders weight columns
             based on activation magnitude, which can improve quantization
             accuracy. Defaults to `False`.
+        quantization_layer_structure: (dict, optional) A dictionary defining the
+            model's quantization structure. It should contain:
+            - "pre_block_layers": list of layers to run before the first block.
+            - "sequential_blocks": list of blocks to be quantized sequentially.
+            If not provided, the model must implement
+            `get_quantization_layer_structure`.
     """
 
     def __init__(
@@ -146,6 +152,7 @@ def __init__(
         group_size: int = 128,
         symmetric: bool = False,
         activation_order: bool = False,
+        quantization_layer_structure: dict = None,
     ):
         if weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
@@ -174,6 +181,7 @@
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
+        self.quantization_layer_structure = quantization_layer_structure
 
     def dtype_policy_string(self):
         """Returns the dtype policy string for this configuration.

keras/src/quantizers/gptq_core.py

Lines changed: 64 additions & 78 deletions
@@ -10,9 +10,9 @@
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
 @contextmanager
@@ -193,38 +193,6 @@ def get_dataloader(
     return samples.astype(np.int32)[:, None, :]
 
 
-def _get_backbone_layers(model):
-    """Extract embedding and transformer layers from a KerasHub model."""
-    backbone = model.backbone
-    if not hasattr(backbone, "transformer_layers"):
-        raise ValueError(
-            "The model's backbone does not have a 'transformer_layers' "
-            "attribute. Please ensure you are using a standard KerasHub "
-            "transformer model."
-        )
-    transformer_blocks = backbone.transformer_layers
-
-    embedding_layer = None
-    if hasattr(backbone, "token_embedding"):
-        embedding_layer = backbone.token_embedding
-    elif hasattr(backbone, "embedding"):
-        embedding_layer = backbone.embedding
-    return embedding_layer, transformer_blocks
-
-
-def _get_custom_layers(model):
-    """Heuristic for extracting embedding + transformer blocks from a custom
-    model."""
-    embedding_layer = None
-    transformer_blocks = []
-    for layer in model.layers:
-        if isinstance(layer, Embedding) and embedding_layer is None:
-            embedding_layer = layer
-        elif getattr(layer, "_layers", None):  # container-like block
-            transformer_blocks.append(layer)
-    return embedding_layer, transformer_blocks
-
-
 def find_layers_in_block(block):
     """
     Finds all Dense and EinsumDense layers in a transformer block.
@@ -242,39 +210,31 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(model, dataloader, config):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function is designed to work with common transformer architectures,
-    like those provided by KerasHub. It automatically discovers the model's
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-    1. It automatically detects the model's structure, identifying the main
-       embedding layer and a sequence of transformer blocks.
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
        block, it uses temporary hooks to capture the input activations of
        each target layer during a forward pass with the calibration data.
-    3. These captured activations are used to compute the Hessian matrix for
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-    4. The GPTQ algorithm is then applied to each layer to find the optimal
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-    5. The output activations from the current block are then used as the
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.
 
     Args:
-        model: The Keras model instance to be quantized. The function will
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
@@ -284,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
 
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "Could not automatically find any transformer-like blocks to "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )
 
-    # Initial inputs are the outputs of the token embedding layer
-    inputs = [
-        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-        for batch in dataloader
-    ]
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +269,19 @@ def apply_gptq_layerwise(model, dataloader, config):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No Dense or EinsumDense layers found in block {block_idx}. "
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -357,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(model, config):
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-    Top-level function to quantize a Keras model using GPTQ.
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-    logging.info("Starting GPTQ quantization process...")
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
@@ -376,7 +357,12 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(model, calibration_dataloader, config)
+    apply_gptq_layerwise(
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
 
 
 def get_group_size_for_layer(layer, config):
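With the auto-discovery helpers (`_get_backbone_layers`, `_get_custom_layers`) removed, the model now declares its own topology. Below is a minimal sketch of how a custom model might override the new hook; the architecture is hypothetical and only illustrates the expected shape of the returned dictionary.

import keras
from keras import layers


class TinyBlockModel(keras.Model):
    """Hypothetical model used only to illustrate the structure hook."""

    def __init__(self, vocab_size=100, hidden_dim=32, num_blocks=2, **kwargs):
        super().__init__(**kwargs)
        self.embedding = layers.Embedding(vocab_size, hidden_dim)
        self.blocks = [
            keras.Sequential(
                [layers.Dense(hidden_dim, activation="relu"),
                 layers.Dense(hidden_dim)]
            )
            for _ in range(num_blocks)
        ]

    def call(self, inputs):
        x = self.embedding(inputs)
        for block in self.blocks:
            x = block(x)
        return x

    def get_quantization_layer_structure(self, mode):
        # Inputs are first run through the embedding ("pre_block_layers"),
        # then each block is quantized in order ("sequential_blocks").
        if mode == "gptq":
            return {
                "pre_block_layers": [self.embedding],
                "sequential_blocks": self.blocks,
            }
        return None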
