 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
 @contextmanager
@@ -193,38 +193,6 @@ def get_dataloader(
     return samples.astype(np.int32)[:, None, :]
 
 
-def _get_backbone_layers(model):
-    """Extract embedding and transformer layers from a KerasHub model."""
-    backbone = model.backbone
-    if not hasattr(backbone, "transformer_layers"):
-        raise ValueError(
-            "The model's backbone does not have a 'transformer_layers' "
-            "attribute. Please ensure you are using a standard KerasHub "
-            "transformer model."
-        )
-    transformer_blocks = backbone.transformer_layers
-
-    embedding_layer = None
-    if hasattr(backbone, "token_embedding"):
-        embedding_layer = backbone.token_embedding
-    elif hasattr(backbone, "embedding"):
-        embedding_layer = backbone.embedding
-    return embedding_layer, transformer_blocks
-
-
-def _get_custom_layers(model):
-    """Heuristic for extracting embedding + transformer blocks from a custom
-    model."""
-    embedding_layer = None
-    transformer_blocks = []
-    for layer in model.layers:
-        if isinstance(layer, Embedding) and embedding_layer is None:
-            embedding_layer = layer
-        elif getattr(layer, "_layers", None):  # container-like block
-            transformer_blocks.append(layer)
-    return embedding_layer, transformer_blocks
-
-
 def find_layers_in_block(block):
     """
     Finds all Dense and EinsumDense layers in a transformer block.
@@ -242,39 +210,31 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(model, dataloader, config):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function is designed to work with common transformer architectures,
-    like those provided by KerasHub. It automatically discovers the model's
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-    1. It automatically detects the model's structure, identifying the main
-       embedding layer and a sequence of transformer blocks.
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
        block, it uses temporary hooks to capture the input activations of
        each target layer during a forward pass with the calibration data.
-    3. These captured activations are used to compute the Hessian matrix for
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-    4. The GPTQ algorithm is then applied to each layer to find the optimal
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-    5. The output activations from the current block are then used as the
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.
 
     Args:
-        model: The Keras model instance to be quantized. The function will
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
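
A minimal sketch, kept separate from the diff, of what a caller might pass as the new `structure` argument. It assumes a KerasHub-style backbone that exposes `token_embedding` and `transformer_layers`, the same attributes the deleted `_get_backbone_layers` helper used to look up; `backbone`, `dataloader`, and `config` are placeholders, not names introduced by this commit.

```python
# Illustrative only; builds the structure the docstring above describes.
structure = {
    # Run once over every calibration batch to produce the inputs of the
    # first sequential block.
    "pre_block_layers": [backbone.token_embedding],
    # Quantized one block at a time; the outputs of block i become the
    # calibration inputs of block i + 1.
    "sequential_blocks": list(backbone.transformer_layers),
}

apply_gptq_layerwise(dataloader, config, structure, filters=None)
```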
@@ -284,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
 
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "Could not automatically find any transformer-like blocks to "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )
 
-    # Initial inputs are the outputs of the token embedding layer
-    inputs = [
-        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-        for batch in dataloader
-    ]
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +269,19 @@ def apply_gptq_layerwise(model, dataloader, config):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No Dense or EinsumDense layers found in block {block_idx}. "
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
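
The filtering pass added in this hunk has no side effects beyond rebuilding the map, so behaviorally it matches a dict comprehension; a sketch of the same step for reference, not part of the commit:

```python
# Keep only the layers that should_quantize_layer accepts for the given
# filters; equivalent to the explicit loop above.
sub_layers_map = {
    name: layer
    for name, layer in sub_layers_map.items()
    if should_quantize_layer(layer, filters)
}
```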
@@ -357,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(model, config):
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-    Top-level function to quantize a Keras model using GPTQ.
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-    logging.info("Starting GPTQ quantization process...")
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
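
Putting the two new validation checks together, a hedged usage sketch of the reworked entry point. `config.quantization_layer_structure` and `model.get_quantization_layer_structure(mode)` are the hooks named in the error message above; the `GPTQConfig` keyword names and the calibration inputs are assumptions for illustration only.

```python
# Illustrative call sequence; not part of this commit.
config = GPTQConfig(dataset=calibration_texts, tokenizer=tokenizer)

# Either the model-provided structure for "gptq" mode ...
structure = model.get_quantization_layer_structure("gptq")
# ... or an explicit dict supplied via config.quantization_layer_structure.

gptq_quantize(config, structure, filters=None)
```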
@@ -376,7 +357,12 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(model, calibration_dataloader, config)
+    apply_gptq_layerwise(
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
 
 
 def get_group_size_for_layer(layer, config):