From f0f7774beb47ffa48e0d1621ec997e4455db006e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:53:03 +0000 Subject: [PATCH 1/8] update budgets for a100 hardware weightclass --- algoperf/workloads/criteo1tb/workload.py | 4 ++-- algoperf/workloads/fastmri/workload.py | 4 ++-- .../workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 4 ++-- .../librispeech_conformer/workload.py | 4 ++-- .../librispeech_jax/workload.py | 6 +++++- .../librispeech_pytorch/workload.py | 6 +++++- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 4 ++-- docker/build_docker_images.sh | 14 ++++++------- scoring/performance_profile.py | 1 + scoring/score_submissions.py | 4 +++- scoring/scoring_utils.py | 20 +++++++++++++++++++ scoring/utils/run_workloads.py | 7 ++++++- .../workload_metadata_external_tuning.json | 2 +- 15 files changed, 62 insertions(+), 26 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..fb38eacc3 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. + return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..5a8afa2e9 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..b5263e0a6 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. + return 1996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..f8f4f2659 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2571 # 7 mins. 
def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..327e8bc39 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43680 # ~12.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..119049b34 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36949 # 10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..53206200f 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452.
# approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..d972a5486 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 @property def step_hint(self) -> int: diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..22590b9fd 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - echo $DOCKER_TAG_COMMAND - eval $DOCKER_TAG_COMMAND - echo $DOCKER_PUSH_COMMAND - eval $DOCKER_PUSH_COMMAND - echo "To pull container run: " - echo $DOCKER_PULL_COMMAND + # echo $DOCKER_TAG_COMMAND + # eval $DOCKER_TAG_COMMAND + # echo $DOCKER_PUSH_COMMAND + # eval $DOCKER_PUSH_COMMAND + # echo "To pull container run: " + # echo $DOCKER_PULL_COMMAND done diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..b200c6865 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl' ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..4b7bed2b5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -123,6 +123,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): workload_df['accumulated_submission_time'] / workload_df['global_step'] ).iloc[-1][-1] + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( @@ -157,7 +159,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. 
+ workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..c6764e9de 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,10 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' +<<<<<<< Updated upstream +======= + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' +>>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..3d9f78ca1 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,7 +24,7 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { From b93eb3ca97871ed30ddd6b08806a3bbc1ca0bdae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:56:48 +0000 Subject: [PATCH 2/8] formatting --- algoperf/workloads/criteo1tb/workload.py | 2 +- algoperf/workloads/fastmri/workload.py | 2 +- algoperf/workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 2 +- algoperf/workloads/librispeech_conformer/workload.py | 2 +- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index fb38eacc3..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,7 +95,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 8915 # ~2.4 hours. + return 8_915 # ~2.4 hours. 
@property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 5a8afa2e9..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,7 +95,7 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 2745 # ~0.7 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index b5263e0a6..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 49918 # ~13.8 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 1996 # approx 25 evals + return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index f8f4f2659..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 2571 # 7 mins. + return 2_571 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 327e8bc39..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,7 +80,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43680 # ~12.1 hours + return 43_680 # ~12.1 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 119049b34..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,7 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 36949 # 10.3 hours + return 36_949 # 10.3 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 53206200f..002576268 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 11303 # ~3.1 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 452.
# approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index d972a5486..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,7 +89,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 16114 # ~12.0 hours + return 16_114 # ~4.5 hours @property def eval_period_time_sec(self) -> int: From 88b0e47fe9694d35d651a77c1acf8ea9491df5ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:57:34 +0000 Subject: [PATCH 3/8] revert changes to docker build shell script --- docker/build_docker_images.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 22590b9fd..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - # echo $DOCKER_TAG_COMMAND - # eval $DOCKER_TAG_COMMAND - # echo $DOCKER_PUSH_COMMAND - # eval $DOCKER_PUSH_COMMAND - # echo "To pull container run: " - # echo $DOCKER_PULL_COMMAND + echo $DOCKER_TAG_COMMAND + eval $DOCKER_TAG_COMMAND + echo $DOCKER_PUSH_COMMAND + eval $DOCKER_PUSH_COMMAND + echo "To pull container run: " + echo $DOCKER_PULL_COMMAND done From fa946d861aab6d803d88046edb05caa27a79c4ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 04:00:09 +0000 Subject: [PATCH 4/8] fix merge conflict --- scoring/utils/run_workloads.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index c6764e9de..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -271,10 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' -<<<<<<< Updated upstream -======= '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' ->>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' From 4e564d5438398ab40da419413d7cac603dd96261 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 20 Nov 2025 23:17:17 +0000 Subject: [PATCH 5/8] update pytorch --- pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..e1fc84987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,6 @@ jax_cpu = [ jax_gpu = [ "jax[cuda12]==0.7.0", "algoperf[jax_core_deps]", - "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663 ] pytorch_cpu = [ "torch==2.5.1", "torchvision==0.20.1" ] pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", + "torch==2.9.0", + "torchvision==0.24.0", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. ############################################################################### From 6f7d638adc190d9bce3f30ba3314c27dac1a8cc5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 1 Dec 2025 04:33:49 +0000 Subject: [PATCH 6/8] ImageNet and CIFAR mixed-precision support, need to debug slow pytorch - Introduced DTYPE enum to standardize data types (FLOAT32, FLOAT16, BFLOAT16) for JAX and PyTorch.
- Updated input pipelines and model definitions in CIFAR and ImageNet workloads to utilize mixed precision. - Implemented casting policies for parameters and inputs using jmp and torch.autocast. --- algoperf/spec.py | 23 ++++++ .../cifar/cifar_jax/input_pipeline.py | 2 - algoperf/workloads/cifar/cifar_jax/models.py | 8 ++- .../workloads/cifar/cifar_jax/workload.py | 29 ++++++-- .../workloads/cifar/cifar_pytorch/models.py | 24 ++++++- .../workloads/cifar/cifar_pytorch/workload.py | 9 ++- algoperf/workloads/cifar/workload.py | 2 + .../imagenet_resnet/imagenet_jax/models.py | 8 ++- .../imagenet_resnet/imagenet_jax/workload.py | 28 ++++++-- .../imagenet_pytorch/models.py | 52 +++++++++++--- .../imagenet_pytorch/workload.py | 9 ++- .../workloads/imagenet_resnet/workload.py | 2 + .../imagenet_vit/imagenet_jax/models.py | 48 +++++++++---- .../imagenet_vit/imagenet_jax/workload.py | 10 ++- .../imagenet_vit/imagenet_pytorch/models.py | 71 +++++++++++++------ .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- algoperf/workloads/ogbg/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 2 + .../pytorch_nadamw_full_budget.py | 10 +-- scoring/performance_profile.py | 2 +- submission_runner.py | 25 +++++-- 21 files changed, 288 insertions(+), 90 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index b86e55954..8dd00345c 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -6,11 +6,34 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import jax +import jax.numpy as jnp +import torch import torch.nn.functional as F from absl import logging from torch import nn +class DTYPE(enum.Enum): + FLOAT32 = 0 + FLOAT16 = 1 + BFLOAT16 = 2 + + +# Mapping from DTYPE enum to JAX dtypes +JAX_DTYPE_MAP = { + DTYPE.FLOAT32: jnp.float32, + DTYPE.FLOAT16: jnp.float16, + DTYPE.BFLOAT16: jnp.bfloat16, +} + +# Mapping from DTYPE enum to PyTorch dtypes +PYTORCH_DTYPE_MAP = { + DTYPE.FLOAT32: torch.float32, + DTYPE.FLOAT16: torch.float16, + DTYPE.BFLOAT16: torch.bfloat16, +} + + class LossType(enum.Enum): SOFTMAX_CROSS_ENTROPY = 0 SIGMOID_CROSS_ENTROPY = 1 diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 7fbc95bc6..307e9e705 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -11,7 +11,6 @@ import jax import tensorflow as tf import tensorflow_datasets as tfds -from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np @@ -186,5 +185,4 @@ def create_input_iter( ), ds, ) - it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 95238c997..9a4f7fd96 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -31,7 +31,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: bool = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: @@ -41,7 +41,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -66,7 +66,9 @@ def __call__( x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 
2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index defc30121..e6bc5b419 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -18,6 +19,17 @@ class CifarWorkload(BaseCifarWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_cifar_dataset( self, data_rng: spec.RandomState, @@ -80,7 +92,8 @@ def sync_batch_stats( def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') - model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + model = model_cls(num_classes=self._num_classes, dtype=param_dtype) self._model = model input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)( @@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) params = jax_sharding_utils.replicate(params) return params, model_state @@ -110,24 +123,32 @@ def model_fn( del mode del rng del dropout_rate + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 0e08f5c5a..b2b37c001 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -29,11 +29,13 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = 
norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -49,7 +51,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + 3, + self.inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -63,7 +71,7 @@ def __init__( self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -105,7 +113,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -122,6 +138,7 @@ def _make_layer( self.base_width, previous_dilation, norm_layer, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -134,6 +151,7 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..141bef922 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -25,6 +25,8 @@ def __init__(self, *args, **kwargs) -> None: # Is set in submission_runner.py for workloads with PyTorch evaluation # data loaders via the `eval_num_workers` property. 
self._eval_num_workers = None + self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] @property def eval_num_workers(self) -> int: @@ -128,7 +130,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return self._model, None torch.random.manual_seed(rng[0]) - self._model = resnet18(num_classes=self._num_classes) + self._model = resnet18( + num_classes=self._num_classes, dtype=self._param_dtype_pt + ) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -175,7 +179,8 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 31636807c..6866bc918 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -16,6 +16,8 @@ class BaseCifarWorkload(spec.Workload): _num_classes: int = 10 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index ee1ddf427..41551d4d2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -90,7 +90,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: Optional[bool] = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm @@ -99,7 +99,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -125,7 +125,9 @@ def __call__( )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index f73a1b26e..d7a8ede67 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -29,6 +30,17 @@ class ImagenetResNetWorkload(BaseImagenetResNetWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_dataset( 
self, data_rng: spec.RandomState, @@ -89,11 +101,12 @@ def init_model_fn( else: act_fnc = nn.relu + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] model = model_cls( num_classes=self._num_classes, act=act_fnc, bn_init_scale=self.bn_init_scale, - dtype=jnp.float32, + dtype=param_dtype, ) self._model = model input_shape = (1, 224, 224, 3) @@ -159,25 +172,28 @@ def model_fn( del mode del rng del dropout_rate + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: - logits, new_model_state = self._model.apply( + logits, model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) - return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - return logits, model_state + logits = self._mp_policy.cast_to_output(logits) + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index c980faa06..f24ba66b9 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -20,6 +20,7 @@ def conv3x3( stride: int = 1, groups: int = 1, dilation: int = 1, + dtype: torch.dtype = torch.float32, ) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( @@ -31,13 +32,24 @@ def conv3x3( groups=groups, bias=False, dilation=dilation, + dtype=dtype, ) -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: +def conv1x1( + in_planes: int, + out_planes: int, + stride: int = 1, + dtype: torch.dtype = torch.float32, +) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, out_planes, kernel_size=1, stride=stride, bias=False + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False, + dtype=dtype, ) @@ -57,6 +69,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -67,10 +80,10 @@ def __init__( raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv3x3(inplanes, planes, stride) + self.conv1 = conv3x3(inplanes, planes, stride, dtype=dtype) self.bn1 = norm_layer(planes) self.act_fnc = act_fnc - self.conv2 = conv3x3(planes, planes) + self.conv2 = conv3x3(planes, planes, dtype=dtype) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride @@ -110,6 +123,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -117,11 +131,11 @@ def __init__( width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. 
- self.conv1 = conv1x1(inplanes, width) + self.conv1 = conv1x1(inplanes, width, dtype=dtype) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.conv2 = conv3x3(width, width, stride, groups, dilation, dtype=dtype) self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) + self.conv3 = conv1x1(width, planes * self.expansion, dtype=dtype) self.bn3 = norm_layer(planes * self.expansion) self.act_fnc = act_fnc self.downsample = downsample @@ -163,11 +177,13 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), bn_init_scale: float = 0.0, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -183,7 +199,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc @@ -214,7 +236,7 @@ def __init__( dilate=replace_stride_with_dilation[2], ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -256,7 +278,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -274,6 +304,7 @@ def _make_layer( previous_dilation, norm_layer, act_fnc, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -287,6 +318,7 @@ def _make_layer( dilation=self.dilation, norm_layer=norm_layer, act_fnc=act_fnc, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..3a88245ae 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -178,7 +178,10 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: act_fnc = torch.nn.ReLU(inplace=True) - model = resnet50(act_fnc=act_fnc, bn_init_scale=self.bn_init_scale) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + model = resnet50( + act_fnc=act_fnc, bn_init_scale=self.bn_init_scale, dtype=param_dtype + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -229,8 +232,10 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index de8458c92..bc5982f1d 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ 
b/algoperf/workloads/imagenet_resnet/workload.py @@ -8,6 +8,8 @@ class BaseImagenetResNetWorkload(spec.Workload): _num_classes: int = 1000 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index e86233011..2e4630701 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -42,6 +42,7 @@ class MlpBlock(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -54,15 +55,15 @@ def __call__( } d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) + x = nn.Dense(self.mlp_dim or 4 * d, param_dtype=self.dtype, **inits)(x) x = nn.gelu(x) if self.use_glu: - y = nn.Dense(self.mlp_dim, **inits)(x) + y = nn.Dense(self.mlp_dim, param_dtype=self.dtype, **inits)(x) x = x * y x = Dropout(dropout_rate)(x, train, rate=dropout_rate) - x = nn.Dense(d, **inits)(x) + x = nn.Dense(d, param_dtype=self.dtype, **inits)(x) return x @@ -74,25 +75,30 @@ class Encoder1DBlock(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate ) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - y = nn.LayerNorm(name='LayerNorm_2')(x) + y = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, + name='MlpBlock_3', )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y @@ -103,21 +109,23 @@ def __call__( kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) + x = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, + dtype=self.dtype, name='MlpBlock_3', dropout_rate=dropout_rate, )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + x = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) return x @@ -130,6 +138,7 @@ class Encoder(nn.Module): num_heads: int = 12 use_glu: bool = False use_post_layer_norm: bool = False + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -143,9 +152,10 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) + return nn.LayerNorm(name='encoder_layernorm', param_dtype=self.dtype)(x) 
else: return x @@ -156,12 +166,13 @@ class MAPHead(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param( - 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), self.dtype ) probe = jnp.tile(probe, [n, 1, 1]) @@ -169,10 +180,13 @@ def __call__(self, x, dropout_rate=DROPOUT_RATE): num_heads=self.num_heads, use_bias=True, kernel_init=nn.initializers.xavier_uniform(), + param_dtype=self.dtype, )(probe, x) - y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) + y = nn.LayerNorm(param_dtype=self.dtype)(x) + x = x + MlpBlock( + mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, dtype=self.dtype + )(y) return x[:, 0] @@ -192,6 +206,7 @@ class ViT(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False use_map: bool = False + dtype: jnp.dtype = jnp.float32 def get_posemb( self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 @@ -209,6 +224,7 @@ def __call__( strides=self.patch_size, padding='VALID', name='conv_patch_extract', + param_dtype=self.dtype, )(x) n, h, w, c = x.shape @@ -225,6 +241,7 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) @@ -233,18 +250,21 @@ def __call__( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, + dtype=self.dtype, )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits') + hid = nn.Dense(rep_size, name='pre_logits', param_dtype=self.dtype) x = nn.tanh(hid(x)) if self.num_classes: kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense(self.num_classes, name='head', **kw) + head = nn.Dense( + self.num_classes, name='head', param_dtype=self.dtype, **kw + ) x = head(x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 8a33aeb47..6819a4862 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,13 @@ def initialized( return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) params, model_state = self.initialized(rng, self._model) @@ -62,15 +64,19 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm - del use_running_average_bn + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - augmented_and_preprocessed_input_batch['inputs'], + inputs, rngs={'dropout': rng}, train=train, dropout_rate=dropout_rate, ) + logits = self._mp_policy.cast_to_output(logits) return logits, None def _eval_model_on_split( diff 
--git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..6dfb5fddf 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -46,22 +46,24 @@ def __init__( width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. use_glu: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu + self.dtype = dtype - self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.linear1 = nn.Linear(self.width, self.mlp_dim, dtype=self.dtype) self.act_fnc = nn.GELU(approximate='tanh') if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim, dtype=self.dtype) else: self.glu_linear = None - self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.linear2 = nn.Linear(self.mlp_dim, self.width, dtype=self.dtype) self.reset_parameters() @@ -85,14 +87,18 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: return x +# TODO(rka97): switch this to built-in attention with cudnn class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__(self, width: int, num_heads: int = 8) -> None: + def __init__( + self, width: int, num_heads: int = 8, dtype: Any = torch.float32 + ) -> None: super().__init__() self.width = width self.num_heads = num_heads + self.dtype = dtype assert width % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.' @@ -101,10 +107,10 @@ def __init__(self, width: int, num_heads: int = 8) -> None: self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.query = nn.Linear(self.width, self.all_head_dim) - self.key = nn.Linear(self.width, self.all_head_dim) - self.value = nn.Linear(self.width, self.all_head_dim) - self.out = nn.Linear(self.width, self.width) + self.query = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.key = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.value = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.out = nn.Linear(self.width, self.width, dtype=self.dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -150,6 +156,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() @@ -158,12 +165,18 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) - self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.self_attention1 = SelfAttention( + self.width, self.num_heads, dtype=self.dtype + ) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) self.mlp3 = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -203,6 +216,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: 
super().__init__() @@ -212,6 +226,7 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype self.net = nn.ModuleList( [ @@ -221,13 +236,14 @@ def __init__( self.num_heads, self.use_glu, self.use_post_layer_norm, + dtype=self.dtype, ) for _ in range(depth) ] ) if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) else: self.encoder_norm = None @@ -245,21 +261,32 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" def __init__( - self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 + self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + dtype: torch.dtype = torch.float32, ): super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.dtype = dtype self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True + self.width, + num_heads=self.num_heads, + self_attn=False, + bias=True, + dtype=self.dtype, + ) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.mlp = MlpBlock( + width=self.width, mlp_dim=self.mlp_dim, dtype=self.dtype ) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) - self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape @@ -310,7 +337,7 @@ def __init__( if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size) + self.pre_logits = nn.Linear(self.width, rep_size, dtype=self.dtype) self.conv_patch_extract = nn.Conv2d( self.channels, @@ -318,6 +345,7 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid', + dtype=self.dtype, ) self.encoder = Encoder( @@ -327,13 +355,16 @@ def __init__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, ) if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes) + self.head = nn.Linear(self.width, self.num_classes, dtype=self.dtype) if self.use_map: - self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + self.map = MAPHead( + self.width, self.mlp_dim, self.num_heads, dtype=self.dtype + ) else: self.map = None diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9c6faf70b..bfef3e0a9 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,11 +23,13 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -70,11 +72,13 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model( - 
augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate, - ) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..a6f36fd30 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -396,6 +396,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..285727885 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -315,13 +314,6 @@ def update_params( }, global_step, ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) @@ -372,6 +364,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..84ae3307b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'cifar', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -409,10 +410,15 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - - train_state['accumulated_submission_time'] += ( - train_step_end_time - train_state['last_step_end_time'] - ) + step_time = train_step_end_time - train_state['last_step_end_time'] + train_state['accumulated_submission_time'] += step_time + # Log training progress periodically + if global_step % 10 == 0: + logging.info( + f'Step: {global_step}, ' + f'\tLast step time: {step_time:.4f}s, ' + f'\tTotal time: {train_state["accumulated_submission_time"]:.2f}s' + ) # Check if submission is eligible for an untimed eval. 
if ( @@ -512,10 +518,19 @@ def train_once( latest_eval_result['accumulated_logging_time'] = train_state[ 'accumulated_logging_time' ] + # Calculate average per-step time + avg_per_step_time = ( + train_state['accumulated_submission_time'] / global_step + if global_step > 0 + else 0.0 + ) + latest_eval_result['avg_per_step_time'] = avg_per_step_time time_since_start = latest_eval_result['total_duration'] logging.info( f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}' + f'\tStep: {global_step}, ' + f'\tAvg per-step time: {avg_per_step_time:.4f}s, ' + f'\t{latest_eval_result}' ) eval_results.append((global_step, latest_eval_result)) From 68060195b8d3aa79848d32bd5d0ea8040a634b18 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 11 Dec 2025 03:15:03 +0000 Subject: [PATCH 7/8] Revert "ImageNet and CIFAR mixed-precision support, need to debug slow pytorch" This reverts commit 6f7d638adc190d9bce3f30ba3314c27dac1a8cc5. --- algoperf/spec.py | 23 ------ .../cifar/cifar_jax/input_pipeline.py | 2 + algoperf/workloads/cifar/cifar_jax/models.py | 8 +-- .../workloads/cifar/cifar_jax/workload.py | 29 ++------ .../workloads/cifar/cifar_pytorch/models.py | 24 +------ .../workloads/cifar/cifar_pytorch/workload.py | 9 +-- algoperf/workloads/cifar/workload.py | 2 - .../imagenet_resnet/imagenet_jax/models.py | 8 +-- .../imagenet_resnet/imagenet_jax/workload.py | 28 ++------ .../imagenet_pytorch/models.py | 52 +++----------- .../imagenet_pytorch/workload.py | 9 +-- .../workloads/imagenet_resnet/workload.py | 2 - .../imagenet_vit/imagenet_jax/models.py | 48 ++++--------- .../imagenet_vit/imagenet_jax/workload.py | 10 +-- .../imagenet_vit/imagenet_pytorch/models.py | 71 ++++++------------- .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- algoperf/workloads/ogbg/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 2 - .../pytorch_nadamw_full_budget.py | 10 ++- scoring/performance_profile.py | 2 +- submission_runner.py | 25 ++----- 21 files changed, 90 insertions(+), 288 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index 8dd00345c..b86e55954 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -6,34 +6,11 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import jax -import jax.numpy as jnp -import torch import torch.nn.functional as F from absl import logging from torch import nn -class DTYPE(enum.Enum): - FLOAT32 = 0 - FLOAT16 = 1 - BFLOAT16 = 2 - - -# Mapping from DTYPE enum to JAX dtypes -JAX_DTYPE_MAP = { - DTYPE.FLOAT32: jnp.float32, - DTYPE.FLOAT16: jnp.float16, - DTYPE.BFLOAT16: jnp.bfloat16, -} - -# Mapping from DTYPE enum to PyTorch dtypes -PYTORCH_DTYPE_MAP = { - DTYPE.FLOAT32: torch.float32, - DTYPE.FLOAT16: torch.float16, - DTYPE.BFLOAT16: torch.bfloat16, -} - - class LossType(enum.Enum): SOFTMAX_CROSS_ENTROPY = 0 SIGMOID_CROSS_ENTROPY = 1 diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 307e9e705..7fbc95bc6 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -11,6 +11,7 @@ import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np @@ -185,4 +186,5 @@ def create_input_iter( ), ds, ) + it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py 
b/algoperf/workloads/cifar/cifar_jax/models.py index 9a4f7fd96..95238c997 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -31,7 +31,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: bool = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: @@ -41,7 +41,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - param_dtype=self.dtype, + dtype=self.dtype, ) x = conv( @@ -66,9 +66,7 @@ def __call__( x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - param_dtype=self.dtype, + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype )(x) return x diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index e6bc5b419..defc30121 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,7 +5,6 @@ import jax import jax.numpy as jnp -import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -19,17 +18,6 @@ class CifarWorkload(BaseCifarWorkload): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - output_dtype = compute_dtype - self._mp_policy = jmp.Policy( - compute_dtype=compute_dtype, - param_dtype=param_dtype, - output_dtype=output_dtype, - ) - def _build_cifar_dataset( self, data_rng: spec.RandomState, @@ -92,8 +80,7 @@ def sync_batch_stats( def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - model = model_cls(num_classes=self._num_classes, dtype=param_dtype) + model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)( @@ -102,7 +89,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.replicate(model_state) + model_state = jax_sharding_utils.replicate(params) params = jax_sharding_utils.replicate(params) return params, model_state @@ -123,32 +110,24 @@ def model_fn( del mode del rng del dropout_rate - # Cast params and inputs to compute dtype - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) - # Cast logits to output dtype - logits = self._mp_policy.cast_to_output(logits) return logits, new_model_state else: logits = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False, 
use_running_average_bn=use_running_average_bn, ) - # Cast logits to output dtype - logits = self._mp_policy.cast_to_output(logits) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index b2b37c001..0e08f5c5a 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -29,13 +29,11 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer - self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -51,13 +49,7 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, - self.inplanes, - kernel_size=3, - stride=1, - padding=1, - bias=False, - dtype=dtype, + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -71,7 +63,7 @@ def __init__( self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) - self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) + self.fc = nn.Linear(512 * block.expansion, num_classes) self.reset_parameters() def reset_parameters(self) -> None: @@ -113,15 +105,7 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ( - 'conv', - conv1x1( - self.inplanes, - planes * block.expansion, - stride, - dtype=self.dtype, - ), - ), + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -138,7 +122,6 @@ def _make_layer( self.base_width, previous_dilation, norm_layer, - dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -151,7 +134,6 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, - dtype=self.dtype, ) ) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index 141bef922..a6e8569cc 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -25,8 +25,6 @@ def __init__(self, *args, **kwargs) -> None: # Is set in submission_runner.py for workloads with PyTorch evaluation # data loaders via the `eval_num_workers` property. 
self._eval_num_workers = None - self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype] - self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] @property def eval_num_workers(self) -> int: @@ -130,9 +128,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return self._model, None torch.random.manual_seed(rng[0]) - self._model = resnet18( - num_classes=self._num_classes, dtype=self._param_dtype_pt - ) + self._model = resnet18(num_classes=self._num_classes) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -179,8 +175,7 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 6866bc918..31636807c 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -16,8 +16,6 @@ class BaseCifarWorkload(spec.Workload): _num_classes: int = 10 - _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 - _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index 41551d4d2..ee1ddf427 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -90,7 +90,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: Optional[bool] = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm @@ -99,7 +99,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - param_dtype=self.dtype, + dtype=self.dtype, ) x = conv( @@ -125,9 +125,7 @@ def __call__( )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - param_dtype=self.dtype, + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype )(x) return x diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index d7a8ede67..f73a1b26e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp -import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -30,17 +29,6 @@ class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - output_dtype = compute_dtype - self._mp_policy = jmp.Policy( - compute_dtype=compute_dtype, - param_dtype=param_dtype, - output_dtype=output_dtype, - ) - def _build_dataset( 
self, data_rng: spec.RandomState, @@ -101,12 +89,11 @@ def init_model_fn( else: act_fnc = nn.relu - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] model = model_cls( num_classes=self._num_classes, act=act_fnc, bn_init_scale=self.bn_init_scale, - dtype=param_dtype, + dtype=jnp.float32, ) self._model = model input_shape = (1, 224, 224, 3) @@ -172,28 +159,25 @@ def model_fn( del mode del rng del dropout_rate - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) variables = {'params': params, **model_state} if update_batch_norm: - logits, model_state = self._model.apply( + logits, new_model_state = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) + return logits, new_model_state else: logits = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - logits = self._mp_policy.cast_to_output(logits) - return logits, model_state + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index f24ba66b9..c980faa06 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -20,7 +20,6 @@ def conv3x3( stride: int = 1, groups: int = 1, dilation: int = 1, - dtype: torch.dtype = torch.float32, ) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( @@ -32,24 +31,13 @@ def conv3x3( groups=groups, bias=False, dilation=dilation, - dtype=dtype, ) -def conv1x1( - in_planes: int, - out_planes: int, - stride: int = 1, - dtype: torch.dtype = torch.float32, -) -> nn.Conv2d: +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False, - dtype=dtype, + in_planes, out_planes, kernel_size=1, stride=stride, bias=False ) @@ -69,7 +57,6 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -80,10 +67,10 @@ def __init__( raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv3x3(inplanes, planes, stride, dtype=dtype) + self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.act_fnc = act_fnc - self.conv2 = conv3x3(planes, planes, dtype=dtype) + self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride @@ -123,7 +110,6 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -131,11 +117,11 @@ def __init__( width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. 
- self.conv1 = conv1x1(inplanes, width, dtype=dtype) + self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation, dtype=dtype) + self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion, dtype=dtype) + self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.act_fnc = act_fnc self.downsample = downsample @@ -177,13 +163,11 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), bn_init_scale: float = 0.0, - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer - self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -199,13 +183,7 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, - self.inplanes, - kernel_size=7, - stride=2, - padding=3, - bias=False, - dtype=dtype, + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc @@ -236,7 +214,7 @@ def __init__( dilate=replace_stride_with_dilation[2], ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) + self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -278,15 +256,7 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ( - 'conv', - conv1x1( - self.inplanes, - planes * block.expansion, - stride, - dtype=self.dtype, - ), - ), + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -304,7 +274,6 @@ def _make_layer( previous_dilation, norm_layer, act_fnc, - dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -318,7 +287,6 @@ def _make_layer( dilation=self.dilation, norm_layer=norm_layer, act_fnc=act_fnc, - dtype=self.dtype, ) ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3a88245ae..d5366c60d 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -178,10 +178,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: act_fnc = torch.nn.ReLU(inplace=True) - param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] - model = resnet50( - act_fnc=act_fnc, bn_init_scale=self.bn_init_scale, dtype=param_dtype - ) + model = resnet50(act_fnc=act_fnc, bn_init_scale=self.bn_init_scale) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -232,10 +229,8 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=compute_dtype): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index bc5982f1d..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ 
b/algoperf/workloads/imagenet_resnet/workload.py @@ -8,8 +8,6 @@ class BaseImagenetResNetWorkload(spec.Workload): _num_classes: int = 1000 - _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 - _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 2e4630701..e86233011 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -42,7 +42,6 @@ class MlpBlock(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -55,15 +54,15 @@ def __call__( } d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, param_dtype=self.dtype, **inits)(x) + x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) if self.use_glu: - y = nn.Dense(self.mlp_dim, param_dtype=self.dtype, **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = Dropout(dropout_rate)(x, train, rate=dropout_rate) - x = nn.Dense(d, param_dtype=self.dtype, **inits)(x) + x = nn.Dense(d, **inits)(x) return x @@ -75,30 +74,25 @@ class Encoder1DBlock(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate ) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) + y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', - param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - y = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) + y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dtype=self.dtype, - name='MlpBlock_3', + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y @@ -109,23 +103,21 @@ def __call__( kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', - param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) + x = nn.LayerNorm(name='LayerNorm_0')(x) y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dtype=self.dtype, name='MlpBlock_3', dropout_rate=dropout_rate, )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) + x = nn.LayerNorm(name='LayerNorm_2')(x) return x @@ -138,7 +130,6 @@ class Encoder(nn.Module): num_heads: int = 12 use_glu: bool = False use_post_layer_norm: bool = False - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -152,10 +143,9 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm', param_dtype=self.dtype)(x) + return nn.LayerNorm(name='encoder_layernorm')(x) 
else: return x @@ -166,13 +156,12 @@ class MAPHead(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param( - 'probe', nn.initializers.xavier_uniform(), (1, 1, d), self.dtype + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype ) probe = jnp.tile(probe, [n, 1, 1]) @@ -180,13 +169,10 @@ def __call__(self, x, dropout_rate=DROPOUT_RATE): num_heads=self.num_heads, use_bias=True, kernel_init=nn.initializers.xavier_uniform(), - param_dtype=self.dtype, )(probe, x) - y = nn.LayerNorm(param_dtype=self.dtype)(x) - x = x + MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, dtype=self.dtype - )(y) + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) return x[:, 0] @@ -206,7 +192,6 @@ class ViT(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False use_map: bool = False - dtype: jnp.dtype = jnp.float32 def get_posemb( self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 @@ -224,7 +209,6 @@ def __call__( strides=self.patch_size, padding='VALID', name='conv_patch_extract', - param_dtype=self.dtype, )(x) n, h, w, c = x.shape @@ -241,7 +225,6 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) @@ -250,21 +233,18 @@ def __call__( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, - dtype=self.dtype, )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits', param_dtype=self.dtype) + hid = nn.Dense(rep_size, name='pre_logits') x = nn.tanh(hid(x)) if self.num_classes: kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense( - self.num_classes, name='head', param_dtype=self.dtype, **kw - ) + head = nn.Dense(self.num_classes, name='head', **kw) x = head(x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 6819a4862..8a33aeb47 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,13 +32,11 @@ def initialized( return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - dtype=param_dtype, **decode_variant('S/16'), ) params, model_state = self.initialized(rng, self._model) @@ -64,19 +62,15 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm - # Cast params and inputs to compute dtype - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) + del use_running_average_bn train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - inputs, + augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, train=train, dropout_rate=dropout_rate, ) - logits = self._mp_policy.cast_to_output(logits) return logits, None def _eval_model_on_split( diff 
--git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 6dfb5fddf..fc2a3cd46 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -46,24 +46,22 @@ def __init__( width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. use_glu: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dtype = dtype - self.linear1 = nn.Linear(self.width, self.mlp_dim, dtype=self.dtype) + self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim, dtype=self.dtype) + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) else: self.glu_linear = None - self.linear2 = nn.Linear(self.mlp_dim, self.width, dtype=self.dtype) + self.linear2 = nn.Linear(self.mlp_dim, self.width) self.reset_parameters() @@ -87,18 +85,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: return x -# TODO(rka97): switch this to built-in attention with cudnn class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__( - self, width: int, num_heads: int = 8, dtype: Any = torch.float32 - ) -> None: + def __init__(self, width: int, num_heads: int = 8) -> None: super().__init__() self.width = width self.num_heads = num_heads - self.dtype = dtype assert width % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.' @@ -107,10 +101,10 @@ def __init__( self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.query = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.key = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.value = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.out = nn.Linear(self.width, self.width, dtype=self.dtype) + self.query = nn.Linear(self.width, self.all_head_dim) + self.key = nn.Linear(self.width, self.all_head_dim) + self.value = nn.Linear(self.width, self.all_head_dim) + self.out = nn.Linear(self.width, self.width) self.reset_parameters() def reset_parameters(self) -> None: @@ -156,7 +150,6 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() @@ -165,18 +158,12 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dtype = dtype - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) - self.self_attention1 = SelfAttention( - self.width, self.num_heads, dtype=self.dtype - ) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) + self.self_attention1 = SelfAttention(self.width, self.num_heads) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dtype=self.dtype, + width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -216,7 +203,6 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() @@ -226,7 +212,6 @@ def __init__( 
self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dtype = dtype self.net = nn.ModuleList( [ @@ -236,14 +221,13 @@ def __init__( self.num_heads, self.use_glu, self.use_post_layer_norm, - dtype=self.dtype, ) for _ in range(depth) ] ) if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) else: self.encoder_norm = None @@ -261,32 +245,21 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - dtype: torch.dtype = torch.float32, + self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 ): super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads - self.dtype = dtype self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, - num_heads=self.num_heads, - self_attn=False, - bias=True, - dtype=self.dtype, - ) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) - self.mlp = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, dtype=self.dtype + self.width, num_heads=self.num_heads, self_attn=False, bias=True ) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape @@ -337,7 +310,7 @@ def __init__( if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size, dtype=self.dtype) + self.pre_logits = nn.Linear(self.width, rep_size) self.conv_patch_extract = nn.Conv2d( self.channels, @@ -345,7 +318,6 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid', - dtype=self.dtype, ) self.encoder = Encoder( @@ -355,16 +327,13 @@ def __init__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, ) if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes, dtype=self.dtype) + self.head = nn.Linear(self.width, self.num_classes) if self.use_map: - self.map = MAPHead( - self.width, self.mlp_dim, self.num_heads, dtype=self.dtype - ) + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) else: self.map = None diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index bfef3e0a9..9c6faf70b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,13 +23,11 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) - param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - dtype=param_dtype, **decode_variant('S/16'), ) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -72,13 +70,11 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=compute_dtype): - logits_batch = model( - 
augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate, - ) + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 771b103a0..002576268 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index a6f36fd30..0577cd4e0 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -396,8 +396,6 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'cifar': - return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 285727885..0b32199ba 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,6 +5,7 @@ import torch import torch.distributed.nn as dist_nn +from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -314,6 +315,13 @@ def update_params( }, global_step, ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + return (optimizer_state, current_param_container, new_model_state) @@ -364,8 +372,6 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'cifar': - return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 043a65791..b200c6865 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl', + 'ppl' ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/submission_runner.py b/submission_runner.py index 84ae3307b..552c99b79 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,7 +266,6 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', - 'cifar', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -410,15 +409,10 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - step_time = train_step_end_time - train_state['last_step_end_time'] - train_state['accumulated_submission_time'] += step_time - # Log training progress periodically - if global_step % 10 == 0: - logging.info( - f'Step: {global_step}, ' - f'\tLast step time: {step_time:.4f}s, ' - f'\tTotal time: {train_state["accumulated_submission_time"]:.2f}s' - ) + + train_state['accumulated_submission_time'] += ( + train_step_end_time - train_state['last_step_end_time'] + ) # Check if submission is eligible for an untimed eval. 
if ( @@ -518,19 +512,10 @@ def train_once( latest_eval_result['accumulated_logging_time'] = train_state[ 'accumulated_logging_time' ] - # Calculate average per-step time - avg_per_step_time = ( - train_state['accumulated_submission_time'] / global_step - if global_step > 0 - else 0.0 - ) - latest_eval_result['avg_per_step_time'] = avg_per_step_time time_since_start = latest_eval_result['total_duration'] logging.info( f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, ' - f'\tAvg per-step time: {avg_per_step_time:.4f}s, ' - f'\t{latest_eval_result}' + f'\tStep: {global_step}, \t{latest_eval_result}' ) eval_results.append((global_step, latest_eval_result)) From c9899cfd25f57a8fd9ea32f7d74006ea2548ebf3 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 11 Dec 2025 23:47:19 +0000 Subject: [PATCH 8/8] Use tf32 in pytorch --- algoperf/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..937001b87 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -20,6 +20,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: + torch.set_float32_matmul_precision('high') use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
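For readers unfamiliar with the switch added in the final patch: the short standalone sketch below (not part of the patch series; the tensor shapes and script form are illustrative assumptions) shows what torch.set_float32_matmul_precision('high') does. Under 'high', float32 matrix multiplications on Ampere-class GPUs may be executed with TF32 tensor-core kernels, trading a few mantissa bits for higher throughput, while 'highest' keeps full-precision float32 matmuls.

import torch

# The same call the patch adds to pytorch_setup(); 'high' permits TF32
# (or equivalent reduced-precision) kernels for float32 matmuls,
# 'highest' would keep full float32 precision.
torch.set_float32_matmul_precision('high')

if torch.cuda.is_available():
    # Illustrative sizes only; any float32 matmul is affected by the setting.
    a = torch.randn(4096, 4096, device='cuda')
    b = torch.randn(4096, 4096, device='cuda')
    c = a @ b  # eligible for TF32 tensor cores under the 'high' setting
    print(c.dtype)  # still torch.float32; only the internal matmul math changes

Because the setting is process-wide, placing it once in pytorch_setup() means every workload's float32 matmuls pick it up without per-model changes.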