From 6afa518fb5d12aa4cebc02a93a5122b3e0a63cdf Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 17 Dec 2025 16:21:06 -0800 Subject: [PATCH 1/2] [slim tensor migration 1/n] introduce slimtensor required c10 functions This stack aims to migrate slim tensor into ExecuTorch stack to make it as internal tensor representation of cudabackend. This diff introduce slimtensor required c10 dependencies into ExecuTorch by copy and paste c10 headers slim tensor needs but not show in the ExecuTorch stack. Note that to support slimtensor first, in this diff we just copy and paste required c10 files, but not making it the same as current c10 in pytorch. We will try to sync it with latest c10 and move them into `executorch/runtime/core/portable_type/c10/c10/` after slim tensor migration done. Differential Revision: [D89417354](https://our.internmc.facebook.com/intern/diff/D89417354/) [ghstack-poisoned] --- backends/aoti/slim/c10/Contiguity.h | 162 +++++++++ backends/aoti/slim/c10/MemoryFormat.h | 260 ++++++++++++++ backends/aoti/slim/c10/SizesAndStrides.h | 415 +++++++++++++++++++++++ backends/aoti/slim/c10/TARGETS | 3 + backends/aoti/slim/c10/WrapDimMinimal.h | 80 +++++ backends/aoti/slim/c10/targets.bzl | 25 ++ 6 files changed, 945 insertions(+) create mode 100644 backends/aoti/slim/c10/Contiguity.h create mode 100644 backends/aoti/slim/c10/MemoryFormat.h create mode 100644 backends/aoti/slim/c10/SizesAndStrides.h create mode 100644 backends/aoti/slim/c10/TARGETS create mode 100644 backends/aoti/slim/c10/WrapDimMinimal.h create mode 100644 backends/aoti/slim/c10/targets.bzl diff --git a/backends/aoti/slim/c10/Contiguity.h b/backends/aoti/slim/c10/Contiguity.h new file mode 100644 index 00000000000..929857fccda --- /dev/null +++ b/backends/aoti/slim/c10/Contiguity.h @@ -0,0 +1,162 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +using ::executorch::runtime::ArrayRef; + +template +bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (numel == 0) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (size_d == 1) { + continue; + } + + if (strides[d] != expected_stride) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +// This function will return True if the tensor is contiguous, and False if the +// its not or if we can't determine if it is contiguous due to unbacked symbols +// (it could be either in that case based on the actual runtime data). +template +bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (numel == 0) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (size_d == 1) { + continue; + } + + if (strides[d] != expected_stride) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +template +bool _compute_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 4: { + T expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 5: { + T expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + auto dim = sizes.size(); + if (dim == 1) { + return sizes[0] < 2 || strides[0] == 1; + } + std::vector perm(dim); + for (const auto i : c10::irange(dim)) { + perm[i] = i; + } + // Sort by strides, leaving 0 and 1 sized dims at the end of the array + std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { + if (sizes[a] < 2) { + return false; + } else if (sizes[b] < 2) { + return true; + } + return strides[a] < strides[b]; + }); + T require_stride = 1; + for (const auto i : c10::irange(dim)) { + const auto& size_perm_i = sizes[perm[i]]; + if (size_perm_i < 2) { + return true; + } + if (strides[perm[i]] != require_stride) { + return false; + } + require_stride *= size_perm_i; + } + return true; +} + +} // namespace c10 diff --git a/backends/aoti/slim/c10/MemoryFormat.h b/backends/aoti/slim/c10/MemoryFormat.h new file mode 100644 index 00000000000..e5c155ce58e --- /dev/null +++ b/backends/aoti/slim/c10/MemoryFormat.h @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +#include +#include +#include + +// Memory format is not the property of a Tensor. It is the way to tell an +// operator how the result should be organized in memory and nothing more. That +// means memory format should never be used as return value for any tensor state +// interrogation functions (internally and externally). +// +// Possible options are: +// Preserve: +// If any of the input tensors is in channels_last format, operator output +// should be in channels_last format +// +// Contiguous: +// Regardless of input tensors format, the output should be contiguous +// Tensor. +// +// ChannelsLast: +// Regardless of input tensors format, the output should be in channels_last +// format. + +namespace c10 { + +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::IntArrayRef; + +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +// If you are seeing this, it means that this call site was not checked if +// the memory format could be preserved, and it was switched to old default +// behaviour of contiguous +#define LEGACY_CONTIGUOUS_MEMORY_FORMAT ::c10::get_contiguous_memory_format() + +inline MemoryFormat get_contiguous_memory_format() { + return MemoryFormat::Contiguous; +} + +inline std::ostream& operator<<( + std::ostream& stream, + MemoryFormat memory_format) { + switch (memory_format) { + case MemoryFormat::Preserve: + return stream << "Preserve"; + case MemoryFormat::Contiguous: + return stream << "Contiguous"; + case MemoryFormat::ChannelsLast: + return stream << "ChannelsLast"; + case MemoryFormat::ChannelsLast3d: + return stream << "ChannelsLast3d"; + default: + ET_CHECK_MSG( + false, + "Unknown memory format %d", + static_cast(memory_format)); + } +} + +// Note: Hardcoded the channel last stride indices here to get better +// performance +template +inline std::vector get_channels_last_strides_2d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 4: + strides[1] = 1; + strides[3] = sizes[1]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 3: + strides[0] = 1; + strides[2] = sizes[0]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + ET_CHECK_MSG( + false, + "ChannelsLast2d doesn't support size %zu", + static_cast(sizes.size())); + } +} + +inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { + return get_channels_last_strides_2d(sizes); +} + +template +std::vector get_channels_last_strides_3d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 5: + strides[1] = 1; + strides[4] = sizes[1]; + strides[3] = strides[4] * sizes[4]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 4: + strides[0] = 1; + strides[3] = sizes[0]; + strides[2] = strides[3] * sizes[3]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + ET_CHECK_MSG( + false, + "ChannelsLast3d doesn't support size %zu", + static_cast(sizes.size())); + } +} + +inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { + return get_channels_last_strides_3d(sizes); +} + +// NOTE: +// Below are Helper functions for is_channels_last_strides_xd. +// 1. Please do not combine these helper functions, each helper function handles +// exactly one case of sizes + memory_format, by doing this, the strides indices +// will be a constant array and we can access it using constant index number, +// the compiler will fully unroll the loop on strides indices to gain a better +// performance. +// 2. No error check in helper function, caller ensures the correctness of the +// input +// 3. All helper functions have similar comments, only 1st helper function is +// commented here. +template +inline bool is_channels_last_strides_2d_s4( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + // special case for trivial C dimension. default to NCHW + if (strides[1] == 0) { + return false; + } + // loop strides indices + for (auto& d : {1, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + // Fallback to NCHW as default layout for ambiguous cases + // This is the flaw of implicit memory_format from strides. + // N111 tensor with identical strides for size 1 dimension; + // Two cases could lead us here: + // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + // b. N11W contiguous Tensor sliced on the W-dimension. + // ([N,1,1,1]@[W,W,W,W]) + if (d == 0 && min == strides[1]) { + return false; + } + // This is necessary to: + // 1. distinguish the memory_format of N1H1; + // [H, 1, 1, 1] channels_last stride + // [H, H, 1, 1] contiguous stride + // 2. permutation of 1C1W: + // [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + // [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +template +inline bool is_channels_last_strides_3d_s5( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + if (strides[1] == 0) { + return false; + } + for (auto& d : {1, 4, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + if (d == 0 && min == strides[1]) { + return false; + } + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +template +inline bool is_channels_last_strides_2d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 4: + return is_channels_last_strides_2d_s4(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +inline bool is_channels_last_strides_3d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 5: + return is_channels_last_strides_3d_s5(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +inline bool is_channels_last_strides_2d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_2d(sizes, strides); +} + +inline bool is_channels_last_strides_3d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_3d(sizes, strides); +} + +} // namespace c10 diff --git a/backends/aoti/slim/c10/SizesAndStrides.h b/backends/aoti/slim/c10/SizesAndStrides.h new file mode 100644 index 00000000000..028097370e4 --- /dev/null +++ b/backends/aoti/slim/c10/SizesAndStrides.h @@ -0,0 +1,415 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace c10 { + +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::IntArrayRef; + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer +// to out-of-line array +class SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + bool operator==(const SizesAndStrides& other) const { + if (size_ != other.size_) { + return false; + } + return !( + isInline() + ? std::memcmp( + inlineStorage_, + other.inlineStorage_, + sizeof(inlineStorage_)) + : std::memcmp( + outOfLineStorage_, + other.outOfLineStorage_, + storageBytes(size_))); + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + void set_strides(IntArrayRef strides) { + ET_DCHECK_MSG( + strides.size() == size(), + "strides size %zu must match current size %zu", + strides.size(), + size()); + std::copy(strides.begin(), strides.end(), strides_begin()); + } + + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. + int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; + } + + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (C10_LIKELY( + newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + 0, + bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } + } + + private: + void resizeSlowPath(size_t newSize, size_t oldSize) { + if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + ET_DCHECK_MSG( + !isInline(), + "resizeSlowPath called when fast path should have been hit!"); + int64_t* tempStorage = outOfLineStorage_; + memcpy( + &inlineStorage_[0], + &tempStorage[0], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + memcpy( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &tempStorage[oldSize], + C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ + // HAS BEEN OVERWRITTEN! + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(tempStorage); + } else { + if (isInline()) { + // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD + // OVERWRITE inlineStorage_! + int64_t* tempStorage = + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + static_cast(malloc(storageBytes(newSize))); + ET_CHECK_MSG( + tempStorage, + "Could not allocate memory to change Tensor SizesAndStrides!"); + const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); + const auto bytesToZero = (newSize > oldSize) + ? (newSize - oldSize) * sizeof(tempStorage[0]) + : 0; + memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[oldSize], 0, bytesToZero); + } + memcpy( + &tempStorage[newSize], + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + } + outOfLineStorage_ = tempStorage; + } else { + const bool isGrowing = oldSize < newSize; + if (isGrowing) { + // Resize before shifting so that we have room. + resizeOutOfLineStorage(newSize); + } + // Shift the old strides to their new starting point. Note + // that this does not occur in the inline path above because + // the stride starting point is not moving. + memmove( + outOfLineStorage_ + newSize, + outOfLineStorage_ + oldSize, + std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (!isGrowing) { + // Resize after shifting so that we don't lose data. + resizeOutOfLineStorage(newSize); + } else { + // Zero the end of the sizes portion. + const auto bytesToZero = + (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + memset(&outOfLineStorage_[oldSize], 0, bytesToZero); + memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + } + } + } + size_ = newSize; + } + + bool isInline() const noexcept { + return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + ET_DCHECK_MSG(rhs.isInline(), "rhs must be inline"); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + ET_CHECK_MSG( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + ET_DCHECK_MSG(!isInline(), "must not be inline"); + outOfLineStorage_ = static_cast( + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + realloc(outOfLineStorage_, storageBytes(newSize))); + ET_CHECK_MSG( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_{1}; + union { + int64_t* outOfLineStorage_; + // NOLINTNEXTLINE(*c-array*) + int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; +}; + +} // namespace c10 diff --git a/backends/aoti/slim/c10/TARGETS b/backends/aoti/slim/c10/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/slim/c10/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/slim/c10/WrapDimMinimal.h b/backends/aoti/slim/c10/WrapDimMinimal.h new file mode 100644 index 00000000000..f40d2986dbc --- /dev/null +++ b/backends/aoti/slim/c10/WrapDimMinimal.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +// Different from the original implementation in c10, we don't need +// to support SymInt here. +namespace c10 { +namespace detail { +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. + if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t +maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { + ET_CHECK_MSG( + dim_post_expr >= 0, + "Rank cannot be negative but got %" PRId64, + static_cast(dim_post_expr)); + + if (dim_post_expr == 0) { + ET_CHECK_MSG( + wrap_scalar, + "Dimension specified as %" PRId64 " but tensor has no dimensions", + static_cast(dim)); + return c10::maybe_wrap_dim( + std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); + } + + T min = dim_post_expr * -1; + T max = dim_post_expr - 1; + ET_CHECK_MSG( + min <= dim && dim <= max, + "Dimension out of range (expected to be in range of [%" PRId64 + ", %" PRId64 "], but got %" PRId64 ")", + static_cast(min), + static_cast(max), + static_cast(dim)); + + ET_DCHECK_MSG( + false, "should never reach here as dim should be out-of-bounds"); + return dim; // unreachable, but needed to suppress compiler warnings +} +} // namespace detail +} // namespace c10 diff --git a/backends/aoti/slim/c10/targets.bzl b/backends/aoti/slim/c10/targets.bzl new file mode 100644 index 00000000000..f12ab1009d8 --- /dev/null +++ b/backends/aoti/slim/c10/targets.bzl @@ -0,0 +1,25 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define c10 core targets for SlimTensor. + + These headers provide c10 APIs needed by SlimTensor that are not + available in ExecuTorch's c10 directory (which is synced from PyTorch). + """ + + runtime.cxx_library( + name = "c10_core", + exported_headers = [ + "Contiguity.h", + "MemoryFormat.h", + "SizesAndStrides.h", + "WrapDimMinimal.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ] + ([] if runtime.is_oss else [ + "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", + ]), + ) From bde357d7f5e8757e6c2748b4197a129ec33a77f0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 17 Dec 2025 22:10:19 -0800 Subject: [PATCH 2/2] Update on "[slim tensor migration 1/n] copy and paste slim tensor into ExecuTorch" This stack aims to migrate slim tensor into ExecuTorch stack to make it as internal tensor representation of cudabackend. This diff aims to copy and paste slimtensor class into ExecuTorch codebase, with ZERO change on logic, namespace, class, etc. The only change would be the include path. In the diffs above it I will gradually update the slim tensor class to make it suitable for executorch, and reviewer friendly. This diff will not be landed until all updates have been done. Differential Revision: [D89417354](https://our.internmc.facebook.com/intern/diff/D89417354/) [ghstack-poisoned] --- backends/aoti/slim/CMakeLists.txt | 102 ++ backends/aoti/slim/TARGETS | 5 + backends/aoti/slim/c10/TARGETS | 4 +- backends/aoti/slim/c10/WrapDimMinimal.h | 18 +- .../aoti/slim/c10/{ => core}/Contiguity.h | 23 +- backends/aoti/slim/c10/core/Device.h | 372 ++++++++ backends/aoti/slim/c10/core/DeviceType.h | 134 +++ backends/aoti/slim/c10/core/Layout.h | 53 ++ .../aoti/slim/c10/{ => core}/MemoryFormat.h | 91 +- backends/aoti/slim/c10/core/Scalar.h | 360 +++++++ backends/aoti/slim/c10/core/ScalarType.h | 735 ++++++++++++++ .../slim/c10/{ => core}/SizesAndStrides.h | 105 +- backends/aoti/slim/c10/core/WrapDimMinimal.h | 73 ++ backends/aoti/slim/c10/cuda/Exception.h | 29 + backends/aoti/slim/c10/macros/Macros.h | 219 +++++ backends/aoti/slim/c10/targets.bzl | 38 +- backends/aoti/slim/c10/util/Array.h | 18 + backends/aoti/slim/c10/util/ArrayRef.h | 371 ++++++++ backends/aoti/slim/c10/util/BFloat16-inl.h | 365 +++++++ backends/aoti/slim/c10/util/BFloat16-math.h | 332 +++++++ backends/aoti/slim/c10/util/BFloat16.h | 123 +++ backends/aoti/slim/c10/util/Exception.h | 87 ++ .../aoti/slim/c10/util/Float4_e2m1fn_x2.h | 28 + .../aoti/slim/c10/util/Float8_e4m3fn-inl.h | 297 ++++++ backends/aoti/slim/c10/util/Float8_e4m3fn.h | 238 +++++ .../aoti/slim/c10/util/Float8_e4m3fnuz-inl.h | 312 ++++++ backends/aoti/slim/c10/util/Float8_e4m3fnuz.h | 138 +++ backends/aoti/slim/c10/util/Float8_e5m2-inl.h | 302 ++++++ backends/aoti/slim/c10/util/Float8_e5m2.h | 147 +++ .../aoti/slim/c10/util/Float8_e5m2fnuz-inl.h | 318 +++++++ backends/aoti/slim/c10/util/Float8_e5m2fnuz.h | 138 +++ .../aoti/slim/c10/util/Float8_e8m0fnu-inl.h | 118 +++ backends/aoti/slim/c10/util/Float8_e8m0fnu.h | 119 +++ backends/aoti/slim/c10/util/Float8_fnuz_cvt.h | 64 ++ backends/aoti/slim/c10/util/Half-inl.h | 351 +++++++ backends/aoti/slim/c10/util/Half.h | 424 +++++++++ backends/aoti/slim/c10/util/StringUtil.h | 16 + backends/aoti/slim/c10/util/TypeCast.h | 236 +++++ .../aoti/slim/c10/util/TypeSafeSignMath.h | 141 +++ backends/aoti/slim/c10/util/accumulate.h | 125 +++ backends/aoti/slim/c10/util/bit_cast.h | 44 + backends/aoti/slim/c10/util/bits.h | 61 ++ backends/aoti/slim/c10/util/complex.h | 690 ++++++++++++++ backends/aoti/slim/c10/util/complex_math.h | 500 ++++++++++ backends/aoti/slim/c10/util/complex_utils.h | 46 + backends/aoti/slim/c10/util/copysign.h | 26 + .../aoti/slim/c10/util/floating_point_utils.h | 33 + backends/aoti/slim/c10/util/generic_math.h | 105 ++ backends/aoti/slim/c10/util/irange.h | 123 +++ backends/aoti/slim/c10/util/llvmMathExtras.h | 899 ++++++++++++++++++ backends/aoti/slim/c10/util/overflows.h | 100 ++ backends/aoti/slim/c10/util/qint32.h | 18 + backends/aoti/slim/c10/util/qint8.h | 20 + backends/aoti/slim/c10/util/quint2x4.h | 19 + backends/aoti/slim/c10/util/quint4x2.h | 19 + backends/aoti/slim/c10/util/quint8.h | 18 + backends/aoti/slim/c10/util/safe_numerics.h | 94 ++ backends/aoti/slim/core/SlimTensor.h | 637 +++++++++++++ .../aoti/slim/core/SlimTensorResize-incl.h | 174 ++++ backends/aoti/slim/core/SlimTensorView-incl.h | 152 +++ backends/aoti/slim/core/Storage.h | 307 ++++++ backends/aoti/slim/cuda/Exception.h | 39 + backends/aoti/slim/cuda/Guard.h | 174 ++++ backends/aoti/slim/factory/Empty.h | 35 + backends/aoti/slim/factory/Factory.h | 32 + backends/aoti/slim/factory/FromBlob.h | 36 + backends/aoti/slim/factory/FromScalar.h | 15 + backends/aoti/slim/factory/Pad.h | 106 +++ backends/aoti/slim/targets.bzl | 81 ++ backends/aoti/slim/tests/TARGETS | 5 + backends/aoti/slim/tests/targets.bzl | 31 + .../slim/tests/test_slim_tensor_basic.cpp | 170 ++++ .../aoti/slim/tests/test_slim_tensor_cuda.cpp | 212 +++++ .../aoti/slim/tests/test_type_convert.cpp | 83 ++ backends/aoti/slim/util/SharedPtr.h | 222 +++++ backends/aoti/slim/util/SizeUtil.h | 283 ++++++ 76 files changed, 12646 insertions(+), 132 deletions(-) create mode 100644 backends/aoti/slim/CMakeLists.txt create mode 100644 backends/aoti/slim/TARGETS rename backends/aoti/slim/c10/{ => core}/Contiguity.h (88%) create mode 100644 backends/aoti/slim/c10/core/Device.h create mode 100644 backends/aoti/slim/c10/core/DeviceType.h create mode 100644 backends/aoti/slim/c10/core/Layout.h rename backends/aoti/slim/c10/{ => core}/MemoryFormat.h (68%) create mode 100644 backends/aoti/slim/c10/core/Scalar.h create mode 100644 backends/aoti/slim/c10/core/ScalarType.h rename backends/aoti/slim/c10/{ => core}/SizesAndStrides.h (79%) create mode 100644 backends/aoti/slim/c10/core/WrapDimMinimal.h create mode 100644 backends/aoti/slim/c10/cuda/Exception.h create mode 100644 backends/aoti/slim/c10/macros/Macros.h create mode 100644 backends/aoti/slim/c10/util/Array.h create mode 100644 backends/aoti/slim/c10/util/ArrayRef.h create mode 100644 backends/aoti/slim/c10/util/BFloat16-inl.h create mode 100644 backends/aoti/slim/c10/util/BFloat16-math.h create mode 100644 backends/aoti/slim/c10/util/BFloat16.h create mode 100644 backends/aoti/slim/c10/util/Exception.h create mode 100644 backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h create mode 100644 backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h create mode 100644 backends/aoti/slim/c10/util/Float8_e4m3fn.h create mode 100644 backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h create mode 100644 backends/aoti/slim/c10/util/Float8_e4m3fnuz.h create mode 100644 backends/aoti/slim/c10/util/Float8_e5m2-inl.h create mode 100644 backends/aoti/slim/c10/util/Float8_e5m2.h create mode 100644 backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h create mode 100644 backends/aoti/slim/c10/util/Float8_e5m2fnuz.h create mode 100644 backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h create mode 100644 backends/aoti/slim/c10/util/Float8_e8m0fnu.h create mode 100644 backends/aoti/slim/c10/util/Float8_fnuz_cvt.h create mode 100644 backends/aoti/slim/c10/util/Half-inl.h create mode 100644 backends/aoti/slim/c10/util/Half.h create mode 100644 backends/aoti/slim/c10/util/StringUtil.h create mode 100644 backends/aoti/slim/c10/util/TypeCast.h create mode 100644 backends/aoti/slim/c10/util/TypeSafeSignMath.h create mode 100644 backends/aoti/slim/c10/util/accumulate.h create mode 100644 backends/aoti/slim/c10/util/bit_cast.h create mode 100644 backends/aoti/slim/c10/util/bits.h create mode 100644 backends/aoti/slim/c10/util/complex.h create mode 100644 backends/aoti/slim/c10/util/complex_math.h create mode 100644 backends/aoti/slim/c10/util/complex_utils.h create mode 100644 backends/aoti/slim/c10/util/copysign.h create mode 100644 backends/aoti/slim/c10/util/floating_point_utils.h create mode 100644 backends/aoti/slim/c10/util/generic_math.h create mode 100644 backends/aoti/slim/c10/util/irange.h create mode 100644 backends/aoti/slim/c10/util/llvmMathExtras.h create mode 100644 backends/aoti/slim/c10/util/overflows.h create mode 100644 backends/aoti/slim/c10/util/qint32.h create mode 100644 backends/aoti/slim/c10/util/qint8.h create mode 100644 backends/aoti/slim/c10/util/quint2x4.h create mode 100644 backends/aoti/slim/c10/util/quint4x2.h create mode 100644 backends/aoti/slim/c10/util/quint8.h create mode 100644 backends/aoti/slim/c10/util/safe_numerics.h create mode 100644 backends/aoti/slim/core/SlimTensor.h create mode 100644 backends/aoti/slim/core/SlimTensorResize-incl.h create mode 100644 backends/aoti/slim/core/SlimTensorView-incl.h create mode 100644 backends/aoti/slim/core/Storage.h create mode 100644 backends/aoti/slim/cuda/Exception.h create mode 100644 backends/aoti/slim/cuda/Guard.h create mode 100644 backends/aoti/slim/factory/Empty.h create mode 100644 backends/aoti/slim/factory/Factory.h create mode 100644 backends/aoti/slim/factory/FromBlob.h create mode 100644 backends/aoti/slim/factory/FromScalar.h create mode 100644 backends/aoti/slim/factory/Pad.h create mode 100644 backends/aoti/slim/targets.bzl create mode 100644 backends/aoti/slim/tests/TARGETS create mode 100644 backends/aoti/slim/tests/targets.bzl create mode 100644 backends/aoti/slim/tests/test_slim_tensor_basic.cpp create mode 100644 backends/aoti/slim/tests/test_slim_tensor_cuda.cpp create mode 100644 backends/aoti/slim/tests/test_type_convert.cpp create mode 100644 backends/aoti/slim/util/SharedPtr.h create mode 100644 backends/aoti/slim/util/SizeUtil.h diff --git a/backends/aoti/slim/CMakeLists.txt b/backends/aoti/slim/CMakeLists.txt new file mode 100644 index 00000000000..b14d47f15c8 --- /dev/null +++ b/backends/aoti/slim/CMakeLists.txt @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +# SlimTensor library for ExecuTorch CUDA backend A lightweight tensor +# implementation for AOTI (Ahead-of-Time Inference) + +# C10 core headers +set(SLIM_C10_HEADERS + c10/core/Device.h c10/core/DeviceType.h c10/Contiguity.h c10/MemoryFormat.h + c10/SizesAndStrides.h c10/WrapDimMinimal.h +) + +# Utility headers +set(SLIM_UTIL_HEADERS util/SharedPtr.h util/SizeUtil.h util/type_convert.h) + +# Core SlimTensor headers +set(SLIM_CORE_HEADERS core/SlimTensor.h core/SlimTensorResize-incl.h + core/SlimTensorView-incl.h core/Storage.h +) + +# Factory headers +set(SLIM_FACTORY_HEADERS factory/Empty.h factory/Factory.h factory/FromBlob.h + factory/FromScalar.h factory/Pad.h +) + +# CUDA headers +set(SLIM_CUDA_HEADERS cuda/Exception.h cuda/Guard.h) + +# All headers combined +set(SLIM_TENSOR_HEADERS + ${SLIM_C10_HEADERS} ${SLIM_UTIL_HEADERS} ${SLIM_CORE_HEADERS} + ${SLIM_FACTORY_HEADERS} ${SLIM_CUDA_HEADERS} +) + +# Header-only interface library for SlimTensor +add_library(slim_tensor INTERFACE) +target_include_directories( + slim_tensor INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../../../.. +) + +# Link to ExecuTorch dependencies +target_link_libraries( + slim_tensor INTERFACE executorch_core extension_data_loader +) + +# CUDA support (if available) +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(slim_tensor INTERFACE CUDA::cudart) +endif() + +# Installation +install(FILES ${SLIM_C10_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/c10/core +) +install(FILES c10/Contiguity.h c10/MemoryFormat.h c10/SizesAndStrides.h + c10/WrapDimMinimal.h + DESTINATION include/executorch/backends/aoti/slim/c10 +) +install(FILES ${SLIM_UTIL_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/util +) +install(FILES ${SLIM_CORE_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/core +) +install(FILES ${SLIM_FACTORY_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/factory +) +install(FILES ${SLIM_CUDA_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/cuda +) + +# Tests (if building tests) +if(EXECUTORCH_BUILD_TESTS) + enable_testing() + + # Basic SlimTensor tests + add_executable(test_slim_tensor_basic tests/test_slim_tensor_basic.cpp) + target_link_libraries( + test_slim_tensor_basic PRIVATE slim_tensor gtest gtest_main + ) + add_test(NAME test_slim_tensor_basic COMMAND test_slim_tensor_basic) + + # Type conversion tests + add_executable(test_type_convert tests/test_type_convert.cpp) + target_link_libraries(test_type_convert PRIVATE slim_tensor gtest gtest_main) + add_test(NAME test_type_convert COMMAND test_type_convert) + + # CUDA tests (if CUDA is enabled) + if(EXECUTORCH_BUILD_CUDA) + add_executable(test_slim_tensor_cuda tests/test_slim_tensor_cuda.cpp) + target_link_libraries( + test_slim_tensor_cuda PRIVATE slim_tensor gtest gtest_main CUDA::cudart + ) + add_test(NAME test_slim_tensor_cuda COMMAND test_slim_tensor_cuda) + endif() +endif() diff --git a/backends/aoti/slim/TARGETS b/backends/aoti/slim/TARGETS new file mode 100644 index 00000000000..0a42614a385 --- /dev/null +++ b/backends/aoti/slim/TARGETS @@ -0,0 +1,5 @@ +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/c10/TARGETS b/backends/aoti/slim/c10/TARGETS index 77871de4469..0a42614a385 100644 --- a/backends/aoti/slim/c10/TARGETS +++ b/backends/aoti/slim/c10/TARGETS @@ -1,3 +1,5 @@ -load("targets.bzl", "define_common_targets") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") define_common_targets() diff --git a/backends/aoti/slim/c10/WrapDimMinimal.h b/backends/aoti/slim/c10/WrapDimMinimal.h index f40d2986dbc..d0b51ff762b 100644 --- a/backends/aoti/slim/c10/WrapDimMinimal.h +++ b/backends/aoti/slim/c10/WrapDimMinimal.h @@ -50,14 +50,14 @@ template T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { ET_CHECK_MSG( dim_post_expr >= 0, - "Rank cannot be negative but got %" PRId64, - static_cast(dim_post_expr)); + "Rank cannot be negative but got %lld", + static_cast(dim_post_expr)); if (dim_post_expr == 0) { ET_CHECK_MSG( wrap_scalar, - "Dimension specified as %" PRId64 " but tensor has no dimensions", - static_cast(dim)); + "Dimension specified as %lld but tensor has no dimensions", + static_cast(dim)); return c10::maybe_wrap_dim( std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); } @@ -66,11 +66,11 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { T max = dim_post_expr - 1; ET_CHECK_MSG( min <= dim && dim <= max, - "Dimension out of range (expected to be in range of [%" PRId64 - ", %" PRId64 "], but got %" PRId64 ")", - static_cast(min), - static_cast(max), - static_cast(dim)); + "Dimension out of range (expected to be in range of [%lld" + ", %lld], but got %lld)", + static_cast(min), + static_cast(max), + static_cast(dim)); ET_DCHECK_MSG( false, "should never reach here as dim should be out-of-bounds"); diff --git a/backends/aoti/slim/c10/Contiguity.h b/backends/aoti/slim/c10/core/Contiguity.h similarity index 88% rename from backends/aoti/slim/c10/Contiguity.h rename to backends/aoti/slim/c10/core/Contiguity.h index 929857fccda..d5ff49561ab 100644 --- a/backends/aoti/slim/c10/Contiguity.h +++ b/backends/aoti/slim/c10/core/Contiguity.h @@ -1,23 +1,12 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - #pragma once -#include -#include +#include +#include #include #include -#include - -namespace c10 { -using ::executorch::runtime::ArrayRef; +namespace standalone::c10 { template bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { @@ -133,7 +122,7 @@ bool _compute_non_overlapping_and_dense( return sizes[0] < 2 || strides[0] == 1; } std::vector perm(dim); - for (const auto i : c10::irange(dim)) { + for (const auto i : irange(dim)) { perm[i] = i; } // Sort by strides, leaving 0 and 1 sized dims at the end of the array @@ -146,7 +135,7 @@ bool _compute_non_overlapping_and_dense( return strides[a] < strides[b]; }); T require_stride = 1; - for (const auto i : c10::irange(dim)) { + for (const auto i : irange(dim)) { const auto& size_perm_i = sizes[perm[i]]; if (size_perm_i < 2) { return true; @@ -159,4 +148,4 @@ bool _compute_non_overlapping_and_dense( return true; } -} // namespace c10 +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/Device.h b/backends/aoti/slim/c10/core/Device.h new file mode 100644 index 00000000000..a9a6d3a8136 --- /dev/null +++ b/backends/aoti/slim/c10/core/Device.h @@ -0,0 +1,372 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Copied from c10/core/DeviceType.h with some modifications + +namespace standalone::c10 { +namespace detail { +enum class DeviceStringParsingState { + kSTART, + kINDEX_START, + kINDEX_REST, + kERROR +}; + +inline DeviceType parse_type(const std::string& device_string) { + static const std::array< + std::pair, + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> + types = {{ + {"cpu", DeviceType::CPU}, + {"cuda", DeviceType::CUDA}, + {"ipu", DeviceType::IPU}, + {"xpu", DeviceType::XPU}, + {"mkldnn", DeviceType::MKLDNN}, + {"opengl", DeviceType::OPENGL}, + {"opencl", DeviceType::OPENCL}, + {"ideep", DeviceType::IDEEP}, + {"hip", DeviceType::HIP}, + {"ve", DeviceType::VE}, + {"fpga", DeviceType::FPGA}, + {"maia", DeviceType::MAIA}, + {"xla", DeviceType::XLA}, + {"lazy", DeviceType::Lazy}, + {"vulkan", DeviceType::Vulkan}, + {"mps", DeviceType::MPS}, + {"meta", DeviceType::Meta}, + {"hpu", DeviceType::HPU}, + {"mtia", DeviceType::MTIA}, + {"privateuseone", DeviceType::PrivateUse1}, + }}; + auto device = std::find_if( + types.begin(), + types.end(), + [&device_string](const std::pair& p) { + return p.first && p.first == device_string; + }); + if (device != types.end()) { + return device->second; + } + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } + std::vector device_names; + for (const auto& it : types) { + if (it.first) { + device_names.push_back(it.first); + } + } + STANDALONE_CHECK( + false, + "Expected one of ", + Join(", ", device_names), + " device type at start of device string: ", + device_string); +} +} // namespace detail + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. A device is +/// uniquely identified by a type, which specifies the type of machine it is +/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the +/// specific compute device when there is more than one of a certain type. The +/// device index is optional, and in its defaulted state represents (abstractly) +/// "the current device". Further, there are two constraints on the value of the +/// device index, if one is explicitly stored: +/// 1. A negative index represents the current device, a non-negative index +/// represents a specific, concrete device, +/// 2. When the device type is CPU, the device index must be zero. +struct Device final { + using Type = DeviceType; + + /// Constructs a new `Device` from a `DeviceType` and an optional device + /// index. + /* implicit */ + Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { + validate(); + } + + /// Constructs a `Device` from a string description, for convenience. + /// The string supplied must follow the following schema: + /// `(cpu|cuda)[:]` + /// where `cpu` or `cuda` specifies the device type, and + /// `:` optionally specifies a device index. + /* implicit */ Device(const std::string& device_string) : Device(Type::CPU) { + STANDALONE_CHECK(!device_string.empty(), "Device string must not be empty"); + + std::string device_name, device_index_str; + detail::DeviceStringParsingState pstate = + detail::DeviceStringParsingState::kSTART; + + // The code below tries to match the string in the variable + // device_string against the regular expression: + // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? + for (size_t i = 0; pstate != detail::DeviceStringParsingState::kERROR && + i < device_string.size(); + ++i) { + const char ch = device_string.at(i); + const unsigned char uch = static_cast(ch); + switch (pstate) { + case detail::DeviceStringParsingState::kSTART: + if (ch != ':') { + if (std::isalpha(uch) || ch == '_') { + device_name.push_back(ch); + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + } else { + pstate = detail::DeviceStringParsingState::kINDEX_START; + } + break; + + case detail::DeviceStringParsingState::kINDEX_START: + if (std::isdigit(uch)) { + device_index_str.push_back(ch); + pstate = detail::DeviceStringParsingState::kINDEX_REST; + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + break; + + case detail::DeviceStringParsingState::kINDEX_REST: + if (device_index_str.at(0) == '0') { + pstate = detail::DeviceStringParsingState::kERROR; + break; + } + if (std::isdigit(uch)) { + device_index_str.push_back(ch); + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + break; + + case detail::DeviceStringParsingState::kERROR: + // Execution won't reach here. + break; + } + } + + const bool has_error = device_name.empty() || + pstate == detail::DeviceStringParsingState::kERROR || + (pstate == detail::DeviceStringParsingState::kINDEX_START && + device_index_str.empty()); + + STANDALONE_CHECK( + !has_error, "Invalid device string: '", device_string, "'"); + + try { + if (!device_index_str.empty()) { + index_ = static_cast(std::stoi(device_index_str)); + } + } catch (const std::exception&) { + STANDALONE_CHECK( + false, + "Could not parse device index '", + device_index_str, + "' in device string '", + device_string, + "'"); + } + type_ = detail::parse_type(device_name); + validate(); + } + + /// Returns true if the type and index of this `Device` matches that of + /// `other`. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this `Device` differs from that of + /// `other`. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the optional index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. + bool has_index() const noexcept { + return index_ != -1; + } + + /// Return true if the device is of CUDA type. + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + /// Return true if the device is of PrivateUse1 type. + bool is_privateuseone() const noexcept { + return type_ == DeviceType::PrivateUse1; + } + + /// Return true if the device is of MPS type. + bool is_mps() const noexcept { + return type_ == DeviceType::MPS; + } + + /// Return true if the device is of HIP type. + bool is_hip() const noexcept { + return type_ == DeviceType::HIP; + } + + /// Return true if the device is of VE type. + bool is_ve() const noexcept { + return type_ == DeviceType::VE; + } + + /// Return true if the device is of XPU type. + bool is_xpu() const noexcept { + return type_ == DeviceType::XPU; + } + + /// Return true if the device is of IPU type. + bool is_ipu() const noexcept { + return type_ == DeviceType::IPU; + } + + /// Return true if the device is of XLA type. + bool is_xla() const noexcept { + return type_ == DeviceType::XLA; + } + + /// Return true if the device is of MTIA type. + bool is_mtia() const noexcept { + return type_ == DeviceType::MTIA; + } + + /// Return true if the device is of HPU type. + bool is_hpu() const noexcept { + return type_ == DeviceType::HPU; + } + + /// Return true if the device is of Lazy type. + bool is_lazy() const noexcept { + return type_ == DeviceType::Lazy; + } + + /// Return true if the device is of Vulkan type. + bool is_vulkan() const noexcept { + return type_ == DeviceType::Vulkan; + } + + /// Return true if the device is of Metal type. + bool is_metal() const noexcept { + return type_ == DeviceType::Metal; + } + + /// Return true if the device is of MAIA type. + bool is_maia() const noexcept { + return type_ == DeviceType::MAIA; + } + + /// Return true if the device is of META type. + bool is_meta() const noexcept { + return type_ == DeviceType::Meta; + } + + /// Return true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Return true if the device supports arbitrary strides. + bool supports_as_strided() const noexcept { + return type_ != DeviceType::IPU && type_ != DeviceType::XLA && + type_ != DeviceType::Lazy && type_ != DeviceType::MTIA; + } + + /// Same string as returned from operator<<. + std::string str() const { + std::string str = DeviceTypeName(type(), /* lower case */ true); + if (has_index()) { + str.push_back(':'); + str.append(std::to_string(index())); + } + return str; + } + + private: + DeviceType type_; + DeviceIndex index_ = -1; + void validate() { + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + index_ >= -1, + "Device index must be -1 or non-negative, got ", + static_cast(index_)); + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", + static_cast(index_)); + } +}; + +inline std::ostream& operator<<(std::ostream& stream, const Device& device) { + stream << device.str(); + return stream; +} +} // namespace standalone::c10 + +namespace std { +template <> +struct hash { + size_t operator()(standalone::c10::Device d) const noexcept { + // Are you here because this static assert failed? Make sure you ensure + // that the bitmasking code below is updated accordingly! + static_assert( + sizeof(standalone::c10::DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert( + sizeof(standalone::c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); + // Note [Hazard when concatenating signed integers] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We must first convert to a same-sized unsigned type, before promoting to + // the result type, to prevent sign extension when any of the values is -1. + // If sign extension occurs, you'll clobber all of the values in the MSB + // half of the resulting integer. + // + // Technically, by C/C++ integer promotion rules, we only need one of the + // uint32_t casts to the result type, but we put in both for explicitness's + // sake. + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/DeviceType.h b/backends/aoti/slim/c10/core/DeviceType.h new file mode 100644 index 00000000000..f2631a48f2d --- /dev/null +++ b/backends/aoti/slim/c10/core/DeviceType.h @@ -0,0 +1,134 @@ +#pragma once + +// Copied from c10/core/DeviceType.h with some modifications: +// * enum values are kept the same as c10 and guarded by device_type_test +// * Make the implementaion header-only +// * Simplify some implementation +// * Disable PrivateUse1 name registration + +#include +#include +#include +#include +#include +#include + +#include + +namespace standalone::c10 { +enum class DeviceType : int8_t { + CPU = 0, + CUDA = 1, // CUDA. + MKLDNN = 2, // Reserved for explicit MKLDNN + OPENGL = 3, // OpenGL + OPENCL = 4, // OpenCL + IDEEP = 5, // IDEEP. + HIP = 6, // AMD HIP + FPGA = 7, // FPGA + MAIA = 8, // ONNX Runtime / Microsoft + XLA = 9, // XLA / TPU + Vulkan = 10, // Vulkan + Metal = 11, // Metal + XPU = 12, // XPU + MPS = 13, // MPS + Meta = 14, // Meta (tensors with no data) + HPU = 15, // HPU / HABANA + VE = 16, // SX-Aurora / NEC + Lazy = 17, // Lazy Tensors + IPU = 18, // Graphcore IPU + MTIA = 19, // Meta training and inference devices + PrivateUse1 = 20, // PrivateUse1 device + // NB: If you add more devices: + // - Change the implementations of DeviceTypeName and isValidDeviceType + // - Change the number below + COMPILE_TIME_MAX_DEVICE_TYPES = 21, +}; + +constexpr DeviceType kCPU = DeviceType::CPU; +constexpr DeviceType kCUDA = DeviceType::CUDA; +constexpr DeviceType kMKLDNN = DeviceType::MKLDNN; +constexpr DeviceType kOPENGL = DeviceType::OPENGL; +constexpr DeviceType kOPENCL = DeviceType::OPENCL; +constexpr DeviceType kIDEEP = DeviceType::IDEEP; +constexpr DeviceType kHIP = DeviceType::HIP; +constexpr DeviceType kFPGA = DeviceType::FPGA; +constexpr DeviceType kMAIA = DeviceType::MAIA; +constexpr DeviceType kXLA = DeviceType::XLA; +constexpr DeviceType kVulkan = DeviceType::Vulkan; +constexpr DeviceType kMetal = DeviceType::Metal; +constexpr DeviceType kXPU = DeviceType::XPU; +constexpr DeviceType kMPS = DeviceType::MPS; +constexpr DeviceType kMeta = DeviceType::Meta; +constexpr DeviceType kHPU = DeviceType::HPU; +constexpr DeviceType kVE = DeviceType::VE; +constexpr DeviceType kLazy = DeviceType::Lazy; +constexpr DeviceType kIPU = DeviceType::IPU; +constexpr DeviceType kMTIA = DeviceType::MTIA; +constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1; + +// define explicit int constant +constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + +static_assert( + COMPILE_TIME_MAX_DEVICE_TYPES <= 21, + "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " + "for this constant to reflect the actual number of DeviceTypes we support " + "in PyTorch; it's important that this number is not too large as we " + "use this to allocate stack arrays in some places in our code. If you " + "are indeed just adding the 20th device type, feel free to change " + "the check to 32; but if you are adding some sort of extensible device " + "types registration, please be aware that you are affecting code that " + "this number is small. Try auditing uses of this constant."); + +// Doesn't support PrivateUse1 name registration in standalone +inline std::string get_privateuse1_backend(bool lower_case = true) { + return lower_case ? "privateuse1" : "PrivateUse1"; +} + +inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { + static const std::string device_names[] = { + "CPU", "CUDA", "MKLDNN", "OPENGL", "OPENCL", "IDEEP", "HIP", + "FPGA", "MAIA", "XLA", "VULKAN", "METAL", "XPU", "MPS", + "META", "HPU", "VE", "LAZY", "IPU", "MTIA"}; + + int idx = static_cast(d); + if (idx < 0 || idx >= COMPILE_TIME_MAX_DEVICE_TYPES) { + STANDALONE_CHECK(false, "Unknown device: ", static_cast(d)); + } + if (d == DeviceType::PrivateUse1) { + return get_privateuse1_backend(lower_case); + } + std::string name = device_names[idx]; + if (lower_case) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + } + return name; +} + +// NB: Per the C++ standard (e.g., +// https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class) +// as long as you cast from the same underlying type, it is always valid to cast +// into an enum class (even if the value would be invalid by the enum.) Thus, +// the caller is allowed to cast a possibly invalid int16_t to DeviceType and +// then pass it to this function. (I considered making this function take an +// int16_t directly, but that just seemed weird.) +inline bool isValidDeviceType(DeviceType d) { + int idx = static_cast(d); + return idx >= 0 && idx < COMPILE_TIME_MAX_DEVICE_TYPES; +} + +inline std::ostream& operator<<(std::ostream& stream, DeviceType type) { + stream << DeviceTypeName(type, /* lower case */ true); + return stream; +} +} // namespace standalone::c10 + +namespace std { +template <> +struct hash { + std::size_t operator()(standalone::c10::DeviceType k) const { + return std::hash()(static_cast(k)); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/Layout.h b/backends/aoti/slim/c10/core/Layout.h new file mode 100644 index 00000000000..79230f23bb7 --- /dev/null +++ b/backends/aoti/slim/c10/core/Layout.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include +#include + +namespace standalone::c10 { +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) { + switch (layout) { + case c10::kStrided: + return stream << "Strided"; + case c10::kSparse: + return stream << "Sparse"; + case c10::kSparseCsr: + return stream << "SparseCsr"; + case c10::kSparseCsc: + return stream << "SparseCsc"; + case c10::kSparseBsr: + return stream << "SparseBsr"; + case c10::kSparseBsc: + return stream << "SparseBsc"; + case c10::kMkldnn: + return stream << "Mkldnn"; + case c10::kJagged: + return stream << "Jagged"; + default: + STANDALONE_CHECK(false, "Unknown layout"); + } +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/MemoryFormat.h b/backends/aoti/slim/c10/core/MemoryFormat.h similarity index 68% rename from backends/aoti/slim/c10/MemoryFormat.h rename to backends/aoti/slim/c10/core/MemoryFormat.h index e5c155ce58e..756caf64f26 100644 --- a/backends/aoti/slim/c10/MemoryFormat.h +++ b/backends/aoti/slim/c10/core/MemoryFormat.h @@ -1,16 +1,7 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - - #pragma once -#include -#include +#include +#include #include #include @@ -34,11 +25,7 @@ // Regardless of input tensors format, the output should be in channels_last // format. -namespace c10 { - -using ::executorch::runtime::ArrayRef; -using ::executorch::runtime::IntArrayRef; - +namespace standalone::c10 { enum class MemoryFormat : int8_t { Contiguous, Preserve, @@ -50,7 +37,8 @@ enum class MemoryFormat : int8_t { // If you are seeing this, it means that this call site was not checked if // the memory format could be preserved, and it was switched to old default // behaviour of contiguous -#define LEGACY_CONTIGUOUS_MEMORY_FORMAT ::c10::get_contiguous_memory_format() +#define LEGACY_CONTIGUOUS_MEMORY_FORMAT \ + ::standalone::c10::get_contiguous_memory_format() inline MemoryFormat get_contiguous_memory_format() { return MemoryFormat::Contiguous; @@ -69,10 +57,7 @@ inline std::ostream& operator<<( case MemoryFormat::ChannelsLast3d: return stream << "ChannelsLast3d"; default: - ET_CHECK_MSG( - false, - "Unknown memory format %d", - static_cast(memory_format)); + STANDALONE_CHECK(false, "Unknown memory format ", memory_format); } } @@ -94,10 +79,8 @@ inline std::vector get_channels_last_strides_2d(ArrayRef sizes) { strides[1] = strides[2] * sizes[2]; return strides; default: - ET_CHECK_MSG( - false, - "ChannelsLast2d doesn't support size %zu", - static_cast(sizes.size())); + STANDALONE_INTERNAL_ASSERT( + false, "ChannelsLast2d doesn't support size ", sizes.size()); } } @@ -123,10 +106,8 @@ std::vector get_channels_last_strides_3d(ArrayRef sizes) { strides[1] = strides[2] * sizes[2]; return strides; default: - ET_CHECK_MSG( - false, - "ChannelsLast3d doesn't support size %zu", - static_cast(sizes.size())); + STANDALONE_INTERNAL_ASSERT( + false, "ChannelsLast3d doesn't support size ", sizes.size()); } } @@ -213,6 +194,56 @@ inline bool is_channels_last_strides_3d_s5( return true; } +// Note [Ambiguous is_channels_last_strides_xd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The flaw of carrying memory_format implicitly through strides is very hard +// to WAR properly. issue #24090 +// Without the history of permutation, we can't infer the memory_format of a +// tensor from the snapshot of its size & stride +// e.g. +// +// 1. We can NOT specify the memory_format of N111 tensor through strides in a +// meaningful way; +// +// 2. Two path that ended up with identical size/stride +// N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W] +// NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C] +// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer +// the memory_format of the original tensor. +// +// Due to the limitations, our temporary WAR `is_channels_last_strides` does the +// best effort to infer whether the original memory_format of a tensor is +// MemoryFormat::ChannelsLast. The two objectives of this function (ordered +// by their importance): +// 1. Ensure that normal shape manipulation does not accidentally change the +// MemoryFormat of an existing tensor. +// 2. Allows user to mark MemoryFormat::ChannelsLast to tensors; +// +// The function does so via checking strides of the tensor, including strides of +// size-1 dimensions. Although conventionally PyTorch implies no restriction on +// trivial stride (stride for size-1 dimension). +// +// Note that this approach is a compromise. We did not solve the problem +// completely. Many cases we will not be able to infer the correct memory +// format. +// The implementation of `is_channels_last_strides` is to serve the objectives: +// MemoryFormat::ChannelsLast has to be explicitly opted-in (no accidental +// conversion); Best effort to maintain the ChannelsLast flag. +// +// Due to the fact that this is not a bulletproof solution, through testing +// (aten/src/ATen/test/memory_format_test.cpp) +// a. we ensure that the common tasks are supported; +// a. we identify corner cases where the implementation compromises on. +// +// By the time accumulated permutation is enabled to replace implicit +// memory_format through strides, we should be updating our tests and fix the +// issues in our tests. +// +// We use Channels Last 2d as an example above. +// This is a general problem for all the is_channels_last_strides_xd +// implementation. Please check the helper functions +// (is_channels_last_strides_*d_s*) for more details. + template inline bool is_channels_last_strides_2d( const ArrayRef sizes, @@ -257,4 +288,4 @@ inline bool is_channels_last_strides_3d( return is_channels_last_strides_3d(sizes, strides); } -} // namespace c10 +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/Scalar.h b/backends/aoti/slim/c10/core/Scalar.h new file mode 100644 index 00000000000..1c61ecb4704 --- /dev/null +++ b/backends/aoti/slim/c10/core/Scalar.h @@ -0,0 +1,360 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Copy-pasted from c10/core/Scalar.h, but dropping SymScalar support + +namespace standalone::c10 { + +/** + * Scalar represents a 0-dimensional tensor which contains a single element. + * Unlike a tensor, numeric literals (in C++) are implicitly convertible to + * Scalar (which is why, for example, we provide both add(Tensor) and + * add(Scalar) overloads for many operations). It may also be used in + * circumstances where you statically know a tensor is 0-dim and single size, + * but don't know its type. + */ +class Scalar { + public: + Scalar() : Scalar(int64_t(0)) {} + +#define DEFINE_IMPLICIT_CTOR(type, name) \ + Scalar(type vv) : Scalar(vv, true) {} + + AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) + AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) + AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR) + + // Helper constructors to allow Scalar creation from long and long long types + // As std::is_same_v is false(except Android), one needs to + // provide a constructor from either long or long long in addition to one from + // int64_t +#if defined(__APPLE__) || defined(__MACOSX) + static_assert( + std::is_same_v, + "int64_t is the same as long long on MacOS"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(_MSC_VER) + static_assert( + std::is_same_v, + "int64_t is the same as long long on Windows"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(__linux__) && !defined(__ANDROID__) + static_assert( + sizeof(void*) != 8 || std::is_same_v, + "int64_t is the same as long on 64 bit Linux"); +#if LONG_MAX != INT_MAX + Scalar(long long vv) : Scalar(vv, true) {} +#endif /* not 32-bit system */ +#endif + + Scalar(uint16_t vv) : Scalar(vv, true) {} + Scalar(uint32_t vv) : Scalar(vv, true) {} + Scalar(uint64_t vv) { + if (vv > static_cast(INT64_MAX)) { + tag = Tag::HAS_u; + v.u = vv; + } else { + tag = Tag::HAS_i; + // NB: no need to use convert, we've already tested convertibility + v.i = static_cast(vv); + } + } + +#undef DEFINE_IMPLICIT_CTOR + + // Value* is both implicitly convertible to SymbolicVariable and bool which + // causes ambiguity error. Specialized constructor for bool resolves this + // problem. + template < + typename T, + typename std::enable_if_t, bool>* = nullptr> + Scalar(T vv) : tag(Tag::HAS_b) { + v.i = convert(vv); + } + +#define DEFINE_ACCESSOR(type, name) \ + type to##name() const { \ + if (Tag::HAS_d == tag) { \ + return checked_convert(v.d, #type); \ + } else if (Tag::HAS_z == tag) { \ + return checked_convert>( \ + v.z, #type); \ + } \ + if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_i == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_u == tag) { \ + return checked_convert(v.u, #type); \ + } \ + STANDALONE_CHECK(false) \ + } + + // TODO: Support ComplexHalf accessor + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) + DEFINE_ACCESSOR(uint16_t, UInt16) + DEFINE_ACCESSOR(uint32_t, UInt32) + DEFINE_ACCESSOR(uint64_t, UInt64) + +#undef DEFINE_ACCESSOR + + // also support scalar.to(); + // Deleted for unsupported types, but specialized below for supported types + template + T to() const = delete; + + // audit uses of data_ptr + const void* data_ptr() const { + return static_cast(&v); + } + + bool isFloatingPoint() const { + return Tag::HAS_d == tag; + } + + bool isIntegral(bool includeBool) const { + return Tag::HAS_i == tag || Tag::HAS_u == tag || + (includeBool && isBoolean()); + } + + bool isComplex() const { + return Tag::HAS_z == tag; + } + bool isBoolean() const { + return Tag::HAS_b == tag; + } + + STANDALONE_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { + if (&other == this) { + return *this; + } + + moveFrom(std::move(other)); + return *this; + } + + STANDALONE_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { + if (&other == this) { + return *this; + } + + *this = Scalar(other); + return *this; + } + + Scalar operator-() const { + STANDALONE_CHECK( + !isBoolean(), + "torch boolean negative, the `-` operator, is not supported."); + if (isFloatingPoint()) { + return Scalar(-v.d); + } else if (isComplex()) { + return Scalar(-v.z); + } else if (isIntegral(false)) { + return Scalar(-v.i); + } + STANDALONE_INTERNAL_ASSERT( + false, "unknown ivalue tag ", static_cast(tag)); + } + + Scalar conj() const { + if (isComplex()) { + return Scalar(std::conj(v.z)); + } else { + return *this; + } + } + + Scalar log() const { + if (isComplex()) { + return std::log(v.z); + } else if (isFloatingPoint()) { + return std::log(v.d); + } else if (isIntegral(false)) { + return std::log(v.i); + } + STANDALONE_INTERNAL_ASSERT( + false, "unknown ivalue tag ", static_cast(tag)); + } + + template < + typename T, + typename std::enable_if_t::value, int> = + 0> + bool equal(T num) const { + if (isComplex()) { + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + return toDouble() == num; + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num; + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num; + } + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + return false; + } else { + STANDALONE_INTERNAL_ASSERT(false); + } + } + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + return v.z == num; + } else if (isFloatingPoint()) { + return (toDouble() == num.real()) && (num.imag() == T()); + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num.real() && num.imag() == T(); + } + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + return false; + } else { + STANDALONE_INTERNAL_ASSERT(false); + } + } + + bool equal(bool num) const { + if (isBoolean()) { + return static_cast(v.i) == num; + } else { + return false; + } + } + + standalone::c10::ScalarType type() const { + if (isComplex()) { + return standalone::c10::ScalarType::ComplexDouble; + } else if (isFloatingPoint()) { + return standalone::c10::ScalarType::Double; + } else if (isIntegral(/*includeBool=*/false)) { + // Represent all integers as long, UNLESS it is unsigned and therefore + // unrepresentable as long + if (Tag::HAS_u == tag) { + return standalone::c10::ScalarType::UInt64; + } + return standalone::c10::ScalarType::Long; + } else if (isBoolean()) { + return standalone::c10::ScalarType::Bool; + } else { + throw std::runtime_error("Unknown scalar type."); + } + } + + Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {} + + // We can't set v in the initializer list using the + // syntax v{ .member = ... } because it doesn't work on MSVC + private: + enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b }; + + // Note [Meaning of HAS_u] + // ~~~~~~~~~~~~~~~~~~~~~~~ + // HAS_u is a bit special. On its face, it just means that we + // are holding an unsigned integer. However, we generally don't + // distinguish between different bit sizes in Scalar (e.g., we represent + // float as double), instead, it represents a mathematical notion + // of some quantity (integral versus floating point). So actually, + // HAS_u is used solely to represent unsigned integers that could + // not be represented as a signed integer. That means only uint64_t + // potentially can get this tag; smaller types like uint8_t fits into a + // regular int and so for BC reasons we keep as an int. + + // NB: assumes that self has already been cleared + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + STANDALONE_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { + v = rhs.v; + tag = rhs.tag; + } + + Tag tag; + + union v_t { + double d{}; + int64_t i; + // See Note [Meaning of HAS_u] + uint64_t u; + standalone::c10::complex z; + // NOLINTNEXTLINE(modernize-use-equals-default) + v_t() {} // default constructor + } v; + + template < + typename T, + typename std::enable_if_t< + std::is_integral_v && !std::is_same_v, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_i) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t< + !std::is_integral_v && !standalone::c10::is_complex::value, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_d) { + v.d = convert(vv); + } + + template < + typename T, + typename std::enable_if_t::value, bool>* = + nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_z) { + v.z = convert(vv); + } +}; + +// define the scalar.to() specializations +#define DEFINE_TO(T, name) \ + template <> \ + inline T Scalar::to() const { \ + return to##name(); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) +DEFINE_TO(uint16_t, UInt16) +DEFINE_TO(uint32_t, UInt32) +DEFINE_TO(uint64_t, UInt64) +#undef DEFINE_TO + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h new file mode 100644 index 00000000000..6daeaad5f2c --- /dev/null +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -0,0 +1,735 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +// For the macros below: +// +// For users: If you want to macro some code for all non-QInt scalar types +// (i.e. types with complete information, you probably want one of the +// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are +// designed to behave similarly to the Dispatch macros with the same name. +// +// For adding a new dtype: In the beginning, we had an idea that there was a +// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to +// iterate over them. But over the years we added weird types which couldn't +// be handled uniformly everywhere and so in the end we ended up with some +// mish-mosh of some helper macros, but mostly use sites making a call about +// what dtypes they can or can't support. So if you want to add a new dtype, +// the preferred resolution is to find a dtype similar to what you want, +// grep for it and edit all the sites you find this way. If you need to add +// a completely new kind of dtype, you're going to have to laboriously audit +// all of the sites everywhere to figure out how it should work. Consulting +// some old PRs where we added new dtypes (check history of this file) can +// help give you an idea where to start. + +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(standalone::c10::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(standalone::c10::complex, ComplexHalf) /* 8 */ \ + _(standalone::c10::complex, ComplexFloat) /* 9 */ \ + _(standalone::c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(standalone::c10::qint8, QInt8) /* 12 */ \ + _(standalone::c10::quint8, QUInt8) /* 13 */ \ + _(standalone::c10::qint32, QInt32) /* 14 */ \ + _(standalone::c10::BFloat16, BFloat16) /* 15 */ \ + _(standalone::c10::quint4x2, QUInt4x2) /* 16 */ \ + _(standalone::c10::quint2x4, QUInt2x4) /* 17 */ \ + _(standalone::c10::bits1x8, Bits1x8) /* 18 */ \ + _(standalone::c10::bits2x4, Bits2x4) /* 19 */ \ + _(standalone::c10::bits4x2, Bits4x2) /* 20 */ \ + _(standalone::c10::bits8, Bits8) /* 21 */ \ + _(standalone::c10::bits16, Bits16) /* 22 */ \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(standalone::c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(standalone::c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(standalone::c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(standalone::c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(standalone::c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(standalone::c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(standalone::c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(standalone::c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(standalone::c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(standalone::c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(standalone::c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(standalone::c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(standalone::c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(standalone::c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(standalone::c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ + +// If you want to support ComplexHalf for real, add ComplexHalf +// into this macro (and change the name). But beware: convert() +// doesn't work for all the conversions you need... +// +// TODO: To add unsigned int types here, we must define accumulate type. +// But uint8 currently accumulates into int64, so we would have to make +// an inconsistent choice for the larger types. Difficult. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(standalone::c10::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(standalone::c10::BFloat16, BFloat16) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) + +// This macro controls many of our C++ APIs, including constructors +// for Scalar as well as the data() and item() accessors on Tensor +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(standalone::c10::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(standalone::c10::complex, ComplexHalf) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(standalone::c10::BFloat16, BFloat16) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) + +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +namespace impl { + +// These are used to map ScalarTypes to C++ types. + +template +struct ScalarTypeToCPPType; + +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + \ + /* This is a workaround for the CUDA bug which prevents */ \ + /* ::detail::ScalarTypeToCType::type being used directly due to */ \ + /* ambiguous reference which can't to be resolved. For some reason it */ \ + /* can't pick between standalone::c10::detail and \ + * standalone::c10::cuda::detail. */ \ + /* For repro example, please see: */ \ + /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ + /* TODO: remove once the bug is fixed. */ \ + static type t; \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) + +#undef SPECIALIZE_ScalarTypeToCPPType + +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + +} // namespace impl + +template +struct CppTypeToScalarType; + +#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std::integral_constant< \ + standalone::c10::ScalarType, \ + standalone::c10::ScalarType::scalar_type> {}; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) + +#undef SPECIALIZE_CppTypeToScalarType + +// NB: despite its generic sounding name, the macros that don't take _AND +// are mostly only used by tensorexpr +#define AT_FORALL_INT_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) + +#define AT_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) + +// These macros are often controlling how many template instantiations we +// create for kernels. It is typically inappropriate to add new dtypes here, +// instead, new types should be added to use sites on a case-by-case basis. +// We generally are not accepting new dtypes due to binary size concerns. + +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE>::t), \ + SCALARTYPE) + +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) + +#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) + +#define AT_FORALL_SCALAR_TYPES_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE5>::t), \ + SCALARTYPE5) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE6>::t), \ + SCALARTYPE6) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE7>::t), \ + SCALARTYPE7) + +#define AT_FORALL_QINT_TYPES(_) \ + _(standalone::c10::qint8, QInt8) \ + _(standalone::c10::quint8, QUInt8) \ + _(standalone::c10::qint32, QInt32) \ + _(standalone::c10::quint4x2, QUInt4x2) \ + _(standalone::c10::quint2x4, QUInt2x4) + +#define AT_FORALL_FLOAT8_TYPES(_) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) + +#define AT_FORALL_COMPLEX_TYPES(_) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) + +#define DEFINE_CONSTANT(_, name) \ + constexpr ScalarType k##name = ScalarType::name; + +// NOLINTNEXTLINE(clang-diagnostic-unused-const-variable) +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) +#undef DEFINE_CONSTANT + +inline const char* toString(ScalarType t) { +#define DEFINE_CASE(_, name) \ + case ScalarType::name: \ + return #name; + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +inline size_t elementSize(ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, name) \ + case ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + default: + STANDALONE_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_ELEMENTSIZE_CASE +} + +inline bool isIntegralType(ScalarType t, bool includeBool) { + bool isIntegral = + (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || + t == ScalarType::Long || t == ScalarType::Short || + t == ScalarType::UInt16 || t == ScalarType::UInt32 || + t == ScalarType::UInt64); + + return isIntegral || (includeBool && t == ScalarType::Bool); +} + +inline bool isFloat8Type(ScalarType t) { + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || + t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; +} + +inline bool isReducedFloatingType(ScalarType t) { + return t == ScalarType::Half || t == ScalarType::BFloat16 || + isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2; +} + +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Double || t == ScalarType::Float || + isReducedFloatingType(t); +} + +inline bool isComplexType(ScalarType t) { + return ( + t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); +} + +inline bool isQIntType(ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + +inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + +inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + [[fallthrough]]; + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + +inline bool isSignedType(ScalarType t) { +#define CASE_ISSIGNED(name) \ + case ScalarType::name: \ + return std::numeric_limits<::standalone::c10::impl::ScalarTypeToCPPTypeT< \ + ScalarType::name>>::is_signed; + + // TODO(#146647): If we expect to have numeric_limits for everything, + // let's just have a big macro for the whole thing. + // If we're hardcoding it, let's just use the macro and a "true"/"false" + // below? + switch (t) { + case ScalarType::QInt8: + case ScalarType::QUInt8: + case ScalarType::QInt32: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + STANDALONE_CHECK(false, "isSignedType not supported for quantized types"); + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: + STANDALONE_CHECK(false, "Bits types are undefined"); + CASE_ISSIGNED(UInt16); + CASE_ISSIGNED(UInt32); + CASE_ISSIGNED(UInt64); + CASE_ISSIGNED(BFloat16); + CASE_ISSIGNED(Float8_e5m2); + CASE_ISSIGNED(Float8_e5m2fnuz); + CASE_ISSIGNED(Float8_e4m3fn); + CASE_ISSIGNED(Float8_e4m3fnuz); + CASE_ISSIGNED(Float8_e8m0fnu); + CASE_ISSIGNED(Byte); + CASE_ISSIGNED(Char); + CASE_ISSIGNED(Short); + CASE_ISSIGNED(Int); + CASE_ISSIGNED(Long); + CASE_ISSIGNED(Half); + CASE_ISSIGNED(Float); + CASE_ISSIGNED(Double); + CASE_ISSIGNED(ComplexHalf); + CASE_ISSIGNED(ComplexFloat); + CASE_ISSIGNED(ComplexDouble); + CASE_ISSIGNED(Bool); + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + case ScalarType::Float4_e2m1fn_x2: + return true; + case ScalarType::UInt1: + case ScalarType::UInt2: + case ScalarType::UInt3: + case ScalarType::UInt4: + case ScalarType::UInt5: + case ScalarType::UInt6: + case ScalarType::UInt7: + return false; + case ScalarType::Undefined: + case ScalarType::NumOptions: + break; + // Do not add default here, but rather define behavior of every new entry + // here. `-Wswitch-enum` would raise a warning in those cases. + } + STANDALONE_CHECK(false, "Unknown ScalarType ", t); +#undef CASE_ISSIGNED +} + +inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + // BFloat16 has range equivalent to Float, + // so we map it to ComplexFloat. + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + STANDALONE_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +// see tensor_attributes.rst for detailed explanation and examples +// of casting rules. +inline bool canCast(const ScalarType from, const ScalarType to) { + // We disallow complex -> non complex, e.g., float_tensor *= complex is + // disallowed. + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + // We disallow float -> integral, e.g., int_tensor *= float is disallowed. + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + + // Treat bool as a distinct "category," to be consistent with type promotion + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same + // category as `bool_tensor`, we would not promote. Differing categories + // implies `bool_tensor += 5` is disallowed. + // + // NB: numpy distinguishes "unsigned" as a category to get the desired + // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: + // * We don't want the performance hit of checking the runtime sign of + // Scalars. + // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + +namespace detail { +constexpr auto u1 = ScalarType::Byte; +constexpr auto i1 = ScalarType::Char; +constexpr auto i2 = ScalarType::Short; +constexpr auto i4 = ScalarType::Int; +constexpr auto i8 = ScalarType::Long; +constexpr auto f2 = ScalarType::Half; +constexpr auto f4 = ScalarType::Float; +constexpr auto f8 = ScalarType::Double; +constexpr auto c2 = ScalarType::ComplexHalf; +constexpr auto c4 = ScalarType::ComplexFloat; +constexpr auto c8 = ScalarType::ComplexDouble; +constexpr auto b1 = ScalarType::Bool; +constexpr auto bf = ScalarType::BFloat16; +constexpr auto ud = ScalarType::Undefined; + +constexpr auto index2dtype = array_of( + u1, + i1, + i2, + i4, + i8, + f2, + f4, + f8, + c2, + c4, + c8, + b1, + bf); + +constexpr std::array(ScalarType::NumOptions)> +calculate_dtype2index() { + std::array(ScalarType::NumOptions)> inverse = {}; + for (int64_t i = 0; i < static_cast(ScalarType::NumOptions); i++) { + inverse[i] = -1; + } + for (int64_t i = 0; i < static_cast(index2dtype.size()); i++) { + inverse[static_cast(index2dtype[i])] = i; + } + return inverse; +} + +constexpr auto dtype2index = calculate_dtype2index(); +} // namespace detail + +inline ScalarType promoteTypes(ScalarType a, ScalarType b) { + using namespace detail; + + // This is generated according to NumPy's promote_types + if (a == ud || b == ud) { + return ScalarType::Undefined; + } + + // If the two types are equal, return that type + if (a == b) { + return a; + } + + // Handle identically equal types + if (isQIntType(a) || isQIntType(b)) { + STANDALONE_CHECK( + false, + "promoteTypes with quantized numbers is not handled yet; figure out " + "what the correct rules should be, offending types: ", + toString(a), + " ", + toString(b)); + } + + if (isBitsType(a) || isBitsType(b)) { + return ScalarType::Undefined; + } + + if (isFloat8Type(a) || isFloat8Type(b)) { + STANDALONE_CHECK( + false, + "Promotion for Float8 Types is not supported, attempted to promote ", + toString(a), + " and ", + toString(b)); + } + + if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) { + // There are two problems with promotion here: + // + // - Our promotion rule for uint8 is inconsistent with Numpy; Numpy + // promotes to uint64, but since we never had uint64 for the longest + // time, we promote to int64. Changing this is BC-breaking + // + // - We must not promote uint64 to int64 because this will overflow. + // + // It'll be a bit of work to fix it, so we're punting on it for now. + // However, float promotion is fine, so we handle that. + if (isFloatingType(a)) { + return a; + } + if (isFloatingType(b)) { + return b; + } + STANDALONE_CHECK( + false, + "Promotion for uint16, uint32, uint64 types is not supported, " + "attempted to promote ", + toString(a), + " and ", + toString(b)); + } + auto ix_a = dtype2index[static_cast(a)]; + STANDALONE_INTERNAL_ASSERT(ix_a != -1); + auto ix_b = dtype2index[static_cast(b)]; + STANDALONE_INTERNAL_ASSERT(ix_b != -1); + + // This table axes must be consistent with index2dtype + // clang-format off + static constexpr std:: + array, index2dtype.size()> + _promoteTypesLookup = {{ + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, + }}; + // clang-format on + return _promoteTypesLookup[ix_a][ix_b]; +} + +inline std::ostream& operator<<( + std::ostream& stream, + standalone::c10::ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/SizesAndStrides.h b/backends/aoti/slim/c10/core/SizesAndStrides.h similarity index 79% rename from backends/aoti/slim/c10/SizesAndStrides.h rename to backends/aoti/slim/c10/core/SizesAndStrides.h index 028097370e4..aef0ddab171 100644 --- a/backends/aoti/slim/c10/SizesAndStrides.h +++ b/backends/aoti/slim/c10/core/SizesAndStrides.h @@ -1,28 +1,16 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - #pragma once -#include -#include -#include - #include #include #include #include -#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 +#include +#include -namespace c10 { +#define STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 -using ::executorch::runtime::ArrayRef; -using ::executorch::runtime::IntArrayRef; +namespace standalone::c10 { // Packed container for TensorImpl sizes and strides. // This design improves on the previous approach of using a pair of @@ -48,14 +36,14 @@ class SizesAndStrides { } ~SizesAndStrides() { - if (C10_UNLIKELY(!isInline())) { + if (STANDALONE_UNLIKELY(!isInline())) { // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) free(outOfLineStorage_); } } SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { - if (C10_LIKELY(rhs.isInline())) { + if (STANDALONE_LIKELY(rhs.isInline())) { copyDataInline(rhs); } else { allocateOutOfLineStorage(size_); @@ -70,9 +58,7 @@ class SizesAndStrides { return !( isInline() ? std::memcmp( - inlineStorage_, - other.inlineStorage_, - sizeof(inlineStorage_)) + inlineStorage_, other.inlineStorage_, sizeof(inlineStorage_)) : std::memcmp( outOfLineStorage_, other.outOfLineStorage_, @@ -83,8 +69,8 @@ class SizesAndStrides { if (this == &rhs) { return *this; } - if (C10_LIKELY(rhs.isInline())) { - if (C10_UNLIKELY(!isInline())) { + if (STANDALONE_LIKELY(rhs.isInline())) { + if (STANDALONE_UNLIKELY(!isInline())) { // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) free(outOfLineStorage_); } @@ -103,7 +89,7 @@ class SizesAndStrides { // Move from rhs. rhs.size() == 0 afterwards. SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { - if (C10_LIKELY(isInline())) { + if (STANDALONE_LIKELY(isInline())) { memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); } else { outOfLineStorage_ = rhs.outOfLineStorage_; @@ -118,8 +104,8 @@ class SizesAndStrides { if (this == &rhs) { return *this; } - if (C10_LIKELY(rhs.isInline())) { - if (C10_UNLIKELY(!isInline())) { + if (STANDALONE_LIKELY(rhs.isInline())) { + if (STANDALONE_UNLIKELY(!isInline())) { // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) free(outOfLineStorage_); } @@ -144,7 +130,7 @@ class SizesAndStrides { } const int64_t* sizes_data() const noexcept { - if (C10_LIKELY(isInline())) { + if (STANDALONE_LIKELY(isInline())) { return &inlineStorage_[0]; } else { return &outOfLineStorage_[0]; @@ -152,7 +138,7 @@ class SizesAndStrides { } int64_t* sizes_data() noexcept { - if (C10_LIKELY(isInline())) { + if (STANDALONE_LIKELY(isInline())) { return &inlineStorage_[0]; } else { return &outOfLineStorage_[0]; @@ -185,41 +171,37 @@ class SizesAndStrides { } void set_strides(IntArrayRef strides) { - ET_DCHECK_MSG( - strides.size() == size(), - "strides size %zu must match current size %zu", - strides.size(), - size()); + STANDALONE_INTERNAL_ASSERT(strides.size() == size()); std::copy(strides.begin(), strides.end(), strides_begin()); } const int64_t* strides_data() const noexcept { - if (C10_LIKELY(isInline())) { - return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; } else { return &outOfLineStorage_[size()]; } } int64_t* strides_data() noexcept { - if (C10_LIKELY(isInline())) { - return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; } else { return &outOfLineStorage_[size()]; } } strides_const_iterator strides_begin() const noexcept { - if (C10_LIKELY(isInline())) { - return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; } else { return &outOfLineStorage_[size()]; } } strides_iterator strides_begin() noexcept { - if (C10_LIKELY(isInline())) { - return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; } else { return &outOfLineStorage_[size()]; } @@ -280,13 +262,16 @@ class SizesAndStrides { if (newSize == oldSize) { return; } - if (C10_LIKELY( - newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (STANDALONE_LIKELY( + newSize <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE && + isInline())) { if (oldSize < newSize) { - const auto bytesToZero = (newSize - oldSize) * sizeof(inlineStorage_[0]); + const auto bytesToZero = + (newSize - oldSize) * sizeof(inlineStorage_[0]); memset(&inlineStorage_[oldSize], 0, bytesToZero); memset( - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + &inlineStorage_ + [STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], 0, bytesToZero); } @@ -298,19 +283,21 @@ class SizesAndStrides { private: void resizeSlowPath(size_t newSize, size_t oldSize) { - if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { - ET_DCHECK_MSG( + if (newSize <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( !isInline(), "resizeSlowPath called when fast path should have been hit!"); int64_t* tempStorage = outOfLineStorage_; memcpy( &inlineStorage_[0], &tempStorage[0], - C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * + sizeof(inlineStorage_[0])); memcpy( - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE], &tempStorage[oldSize], - C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0])); + STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * + sizeof(inlineStorage_[0])); // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ // HAS BEEN OVERWRITTEN! // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) @@ -322,7 +309,7 @@ class SizesAndStrides { int64_t* tempStorage = // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) static_cast(malloc(storageBytes(newSize))); - ET_CHECK_MSG( + STANDALONE_CHECK( tempStorage, "Could not allocate memory to change Tensor SizesAndStrides!"); const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); @@ -335,7 +322,7 @@ class SizesAndStrides { } memcpy( &tempStorage[newSize], - &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE], bytesToCopy); if (bytesToZero) { memset(&tempStorage[newSize + oldSize], 0, bytesToZero); @@ -370,11 +357,11 @@ class SizesAndStrides { } bool isInline() const noexcept { - return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + return size_ <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE; } void copyDataInline(const SizesAndStrides& rhs) { - ET_DCHECK_MSG(rhs.isInline(), "rhs must be inline"); + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); } @@ -385,17 +372,17 @@ class SizesAndStrides { void allocateOutOfLineStorage(size_t size) { // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) outOfLineStorage_ = static_cast(malloc(storageBytes(size))); - ET_CHECK_MSG( + STANDALONE_CHECK( outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); } void resizeOutOfLineStorage(size_t newSize) { - ET_DCHECK_MSG(!isInline(), "must not be inline"); + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); outOfLineStorage_ = static_cast( // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) realloc(outOfLineStorage_, storageBytes(newSize))); - ET_CHECK_MSG( + STANDALONE_CHECK( outOfLineStorage_, "Could not allocate memory for Tensor SizesAndStrides!"); } @@ -408,8 +395,8 @@ class SizesAndStrides { union { int64_t* outOfLineStorage_; // NOLINTNEXTLINE(*c-array*) - int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + int64_t inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; }; }; -} // namespace c10 +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/WrapDimMinimal.h b/backends/aoti/slim/c10/core/WrapDimMinimal.h new file mode 100644 index 00000000000..651421e6d89 --- /dev/null +++ b/backends/aoti/slim/c10/core/WrapDimMinimal.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include + +// Different from the original implementation in c10, we don't need +// to support SymInt here. +namespace standalone::c10 { +namespace detail { +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (STANDALONE_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. + if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return standalone::c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t +maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { + STANDALONE_CHECK( + dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr); + + if (dim_post_expr == 0) { + STANDALONE_CHECK( + wrap_scalar, + "Dimension specified as ", + dim, + " but tensor has no dimensions"); + return standalone::c10::maybe_wrap_dim( + std::move(dim), + /*dim_post_expr=*/1, + /*wrap_scalar=*/false); + } + + T min = dim_post_expr * -1; + T max = dim_post_expr - 1; + STANDALONE_CHECK( + min <= dim && dim <= max, + "Dimension out of range (expected to be in range of [", + min, + ", ", + max, + "], but got ", + dim, + ")"); + + STANDALONE_INTERNAL_ASSERT( + false, "should never reach here as dim should be out-of-bounds"); +} +} // namespace detail +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h new file mode 100644 index 00000000000..bd972c1652d --- /dev/null +++ b/backends/aoti/slim/c10/cuda/Exception.h @@ -0,0 +1,29 @@ +#pragma once +#ifdef USE_CUDA + +#include +#include +#include + +#include +#include +#include + +#include + +#define STANDALONE_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + STANDALONE_CHECK(__err == cudaSuccess, cudaGetErrorString(__err)); \ + } while (0) + +#define STANDALONE_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (STANDALONE_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + STANDALONE_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif diff --git a/backends/aoti/slim/c10/macros/Macros.h b/backends/aoti/slim/c10/macros/Macros.h new file mode 100644 index 00000000000..aa8329263fe --- /dev/null +++ b/backends/aoti/slim/c10/macros/Macros.h @@ -0,0 +1,219 @@ +#pragma once + +#include + +// UBSan (Undefined Behavior Sanitizer) macros +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ +#endif + +// STANDALONE_LIKELY/STANDALONE_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if STANDALONE_LIKELY(some_expr) { +// ... +// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define STANDALONE_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define STANDALONE_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define STANDALONE_LIKELY(expr) (expr) +#define STANDALONE_UNLIKELY(expr) (expr) +#endif + +// On nvcc, STANDALONE_UNLIKELY thwarts missing return statement analysis. In +// cases where the unlikely expression may be a constant, use this macro to +// ensure return statement analysis keeps working (at the cost of not getting +// the likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing +// STANDALONE_UNLIKELY in nvcc is causing us perf problems, this is not yet +// implemented, but this might be an interesting piece of C++ code for an +// intrepid bootcamper to write. +#if defined(__CUDACC__) +#define STANDALONE_UNLIKELY_OR_CONST(e) e +#else +#define STANDALONE_UNLIKELY_OR_CONST(e) STANDALONE_UNLIKELY(e) +#endif + +#define STANDALONE_STRINGIZE_IMPL(x) #x +#define STANDALONE_STRINGIZE(x) STANDALONE_STRINGIZE_IMPL(x) + +#define STANDALONE_CONCATENATE_IMPL(s1, s2) s1##s2 +#define STANDALONE_CONCATENATE(s1, s2) STANDALONE_CONCATENATE_IMPL(s1, s2) + +/** + * STANDALONE_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts + * with str and ends with a unique number. + */ +#ifdef __COUNTER__ +#define STANDALONE_UID __COUNTER__ +#define STANDALONE_ANONYMOUS_VARIABLE(str) \ + STANDALONE_CONCATENATE(str, __COUNTER__) +#else +#define STANDALONE_UID __LINE__ +#define STANDALONE_ANONYMOUS_VARIABLE(str) STANDALONE_CONCATENATE(str, __LINE__) +#endif + +// Private helper macro for workaround MSVC misexpansion of nested macro +// invocations involving __VA_ARGS__. See +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define STANDALONE_EXPAND_MSVC_WORKAROUND(x) x + +/// STANDALONE_NOINLINE - Functions whose declaration is annotated with this +/// will not be inlined. +#ifdef __GNUC__ +#define STANDALONE_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define STANDALONE_NOINLINE __declspec(noinline) +#else +#define STANDALONE_NOINLINE +#endif + +#if defined(_MSC_VER) +#define STANDALONE_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define STANDALONE_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define STANDALONE_ALWAYS_INLINE inline +#endif + +// Unlike STANDALONE_ALWAYS_INLINE, STANDALONE_ALWAYS_INLINE_ATTRIBUTE can be +// used on a lambda. +#if defined(_MSC_VER) +// MSVC 14.39 is reasonably recent and doesn't like +// [[msvc::forceinline]] on a lambda, so don't try to use it. +#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) +#else +#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE +#endif + +#if defined(_MSC_VER) +#define STANDALONE_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define STANDALONE_ATTR_VISIBILITY_HIDDEN \ + __attribute__((__visibility__("hidden"))) +#else +#define STANDALONE_ATTR_VISIBILITY_HIDDEN +#endif + +#define STANDALONE_ERASE \ + STANDALONE_ALWAYS_INLINE STANDALONE_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that STANDALONE_HOST_DEVICE and friends will Just +// Work. See https://github.com/ROCm/hip/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define STANDALONE_HOST_DEVICE __host__ __device__ +#define STANDALONE_DEVICE __device__ +#define STANDALONE_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. +#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, +// STANDALONE_MAX_THREADS_PER_BLOCK and STANDALONE_MIN_BLOCKS_PER_SM are +// kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(STANDALONE_MAX_THREADS_PER_BLOCK(a), +// STANDALONE_MIN_BLOCKS_PER_SM(a, b)), which will also properly respect limits +// on old architectures. +#define STANDALONE_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define STANDALONE_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ + (threads_per_block)))) +// STANDALONE_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define STANDALONE_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy + // and versatility across all architectures. +#define STANDALONE_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((STANDALONE_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define STANDALONE_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (STANDALONE_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (STANDALONE_MIN_BLOCKS_PER_SM( \ + (max_threads_per_block), (min_blocks_per_sm)))) +#else +#define STANDALONE_HOST_DEVICE +#define STANDALONE_HOST +#define STANDALONE_DEVICE +#endif + +#define _STANDALONE_PRAGMA__(string) _Pragma(#string) +#define _STANDALONE_PRAGMA_(string) _STANDALONE_PRAGMA__(string) + +#ifdef __clang__ +#define STANDALONE_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define STANDALONE_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define STANDALONE_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _STANDALONE_PRAGMA_(clang diagnostic ignored flag) +#define STANDALONE_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#define STANDALONE_CLANG_DIAGNOSTIC_POP() +#define STANDALONE_CLANG_DIAGNOSTIC_IGNORE(flag) +#define STANDALONE_CLANG_HAS_WARNING(flag) 0 +#endif diff --git a/backends/aoti/slim/c10/targets.bzl b/backends/aoti/slim/c10/targets.bzl index f12ab1009d8..2bef9f5cf96 100644 --- a/backends/aoti/slim/c10/targets.bzl +++ b/backends/aoti/slim/c10/targets.bzl @@ -1,25 +1,31 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") def define_common_targets(): - """Define c10 core targets for SlimTensor. + """Define c10 library targets for SlimTensor. - These headers provide c10 APIs needed by SlimTensor that are not - available in ExecuTorch's c10 directory (which is synced from PyTorch). + These are portable c10 utilities adapted from torchnative/standalone. """ + # c10 utility headers (ArrayRef, Half, BFloat16, complex, etc.) + # Excludes CUDA-specific headers which require CUDA SDK runtime.cxx_library( - name = "c10_core", - exported_headers = [ - "Contiguity.h", - "MemoryFormat.h", - "SizesAndStrides.h", - "WrapDimMinimal.h", - ], + name = "c10", + exported_headers = glob( + ["**/*.h"], + exclude = ["cuda/**/*.h"], + ), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [], + ) + + # c10 CUDA-specific headers (requires CUDA SDK) + runtime.cxx_library( + name = "c10_cuda", + exported_headers = glob(["cuda/*.h"]), visibility = ["@EXECUTORCH_CLIENTS"], - exported_deps = [ - "//executorch/runtime/core:core", - "//executorch/runtime/platform:platform", - ] + ([] if runtime.is_oss else [ - "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", - ]), + exported_preprocessor_flags = ["-DUSE_CUDA"], + exported_deps = [":c10"], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], ) diff --git a/backends/aoti/slim/c10/util/Array.h b/backends/aoti/slim/c10/util/Array.h new file mode 100644 index 00000000000..39eabc830d1 --- /dev/null +++ b/backends/aoti/slim/c10/util/Array.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { + +// This helper function creates a constexpr std::array +// From a compile time list of values, without requiring you to explicitly +// write out the length. +// +// See also https://stackoverflow.com/a/26351760/23845 +template +inline constexpr auto array_of(T&&... t) -> std::array { + return {{std::forward(t)...}}; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/ArrayRef.h b/backends/aoti/slim/c10/util/ArrayRef.h new file mode 100644 index 00000000000..4a09f7a9335 --- /dev/null +++ b/backends/aoti/slim/c10/util/ArrayRef.h @@ -0,0 +1,371 @@ +//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::ArrayRef. +// removed llvm-specific functionality +// removed some implicit const -> non-const conversions that rely on +// complicated std::enable_if meta-programming +// removed a bunch of slice variants for simplicity... + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { +/// ArrayRef - Represent a constant reference to an array (0 or more elements +/// consecutively in memory), i.e. a start pointer and a length. It allows +/// various APIs to take consecutive elements easily and conveniently. +/// +/// This class does not own the underlying data, it is expected to be used in +/// situations where the data resides in some other buffer, whose lifetime +/// extends past that of the ArrayRef. For this reason, it is not in general +/// safe to store an ArrayRef. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +template +class ArrayRef final { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + void debugCheckNullptrInvariant() { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + Data != nullptr || Length == 0, + "created ArrayRef with nullptr and non-zero length! std::optional " + "relies on this being illegal"); + } + + public: + /// @name Constructors + /// @{ + + /// Construct an empty ArrayRef. + /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an ArrayRef from a single element. + // TODO Make this explicit + constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an ArrayRef from a pointer and length. + constexpr ArrayRef(const T* data, size_t length) + : Data(data), Length(length) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a range. + constexpr ArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) { + debugCheckNullptrInvariant(); + } + + template < + typename Container, + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> + /* implicit */ ArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. + template + /* implicit */ ArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "ArrayRef cannot be constructed from a " + "std::vector bitfield."); + } + + /// Construct an ArrayRef from a std::array + template + /* implicit */ constexpr ArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct an ArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} + + /// Construct an ArrayRef from a std::initializer_list. + /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return Data; + } + constexpr iterator end() const { + return Data + Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return Data; + } + constexpr const_iterator cend() const { + return Data + Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return Length == 0; + } + + constexpr const T* data() const { + return Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return Length; + } + + /// front - Get the first element. + constexpr const T& front() const { + STANDALONE_CHECK( + !empty(), "ArrayRef: attempted to access front() of empty list"); + return Data[0]; + } + + /// back - Get the last element. + constexpr const T& back() const { + STANDALONE_CHECK( + !empty(), "ArrayRef: attempted to access back() of empty list"); + return Data[Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(ArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr ArrayRef slice(size_t N, size_t M) const { + STANDALONE_CHECK( + N + M <= size(), + "ArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + size()); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + constexpr ArrayRef slice(size_t N) const { + STANDALONE_CHECK( + N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); + return slice(N, size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return Data[Index]; + } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + STANDALONE_CHECK( + Index < Length, + "ArrayRef: invalid index Index = ", + Index, + "; Length = ", + Length); + return Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(Data, Data + Length); + } + + /// @} +}; + +template +std::ostream& operator<<(std::ostream& out, ArrayRef list) { + int i = 0; + out << "["; + for (const auto& e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << "]"; + return out; +} + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template +ArrayRef makeArrayRef(const T& OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template +ArrayRef makeArrayRef(const T* data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template +ArrayRef makeArrayRef(const T* begin, const T* end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from a std::vector. +template +ArrayRef makeArrayRef(const std::vector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::array. +template +ArrayRef makeArrayRef(const std::array& Arr) { + return Arr; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template +ArrayRef makeArrayRef(const ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template +ArrayRef& makeArrayRef(ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template +// NOLINTNEXTLINE(*c-arrays*) +ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +// WARNING: Template instantiation will NOT be willing to do an implicit +// conversions to get you to an standalone::c10::ArrayRef, which is why we +// need so many overloads. + +template +bool operator==( + standalone::c10::ArrayRef a1, + standalone::c10::ArrayRef a2) { + return a1.equals(a2); +} + +template +bool operator!=( + standalone::c10::ArrayRef a1, + standalone::c10::ArrayRef a2) { + return !a1.equals(a2); +} + +template +bool operator==(const std::vector& a1, standalone::c10::ArrayRef a2) { + return standalone::c10::ArrayRef(a1).equals(a2); +} + +template +bool operator!=(const std::vector& a1, standalone::c10::ArrayRef a2) { + return !standalone::c10::ArrayRef(a1).equals(a2); +} + +template +bool operator==(standalone::c10::ArrayRef a1, const std::vector& a2) { + return a1.equals(standalone::c10::ArrayRef(a2)); +} + +template +bool operator!=(standalone::c10::ArrayRef a1, const std::vector& a2) { + return !a1.equals(standalone::c10::ArrayRef(a2)); +} + +using IntArrayRef = ArrayRef; + +using IntList + [[deprecated("This alias is deprecated because it doesn't make ownership " + "semantics obvious. Use IntArrayRef instead!")]] = + ArrayRef; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/BFloat16-inl.h b/backends/aoti/slim/c10/util/BFloat16-inl.h new file mode 100644 index 00000000000..4608d9a6c54 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16-inl.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include + +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace standalone::c10 { + +/// Constructors +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16(float value) + : +#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 800 + x(__bfloat16_as_ushort(__float2bfloat16(value))) +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + x(standalone::c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) +#else + // RNE by default + x(detail::round_to_nearest_even(value)) +#endif +{ +} + +/// Implicit conversions +inline STANDALONE_HOST_DEVICE BFloat16::operator float() const { +#if defined(__CUDACC__) && !defined(USE_ROCM) + return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + return float(*reinterpret_cast(&x)); +#else + return detail::f32_from_bits(x); +#endif +} + +#if defined(__CUDACC__) && !defined(USE_ROCM) +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16( + const sycl::ext::oneapi::bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() + const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline STANDALONE_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator/( + const BFloat16& a, + const BFloat16& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator+=( + BFloat16& a, + const BFloat16& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator-=( + BFloat16& a, + const BFloat16& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator*=( + BFloat16& a, + const BFloat16& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator/=( + BFloat16& a, + const BFloat16& b) { + a = a / b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator|( + BFloat16& a, + const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator^( + BFloat16& a, + const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator&( + BFloat16& a, + const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline STANDALONE_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline STANDALONE_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::BFloat16 min() { + return standalone::c10::BFloat16( + 0x0080, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 lowest() { + return standalone::c10::BFloat16( + 0xFF7F, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 max() { + return standalone::c10::BFloat16( + 0x7F7F, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 epsilon() { + return standalone::c10::BFloat16( + 0x3C00, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 round_error() { + return standalone::c10::BFloat16( + 0x3F00, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 infinity() { + return standalone::c10::BFloat16( + 0x7F80, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 quiet_NaN() { + return standalone::c10::BFloat16( + 0x7FC0, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 signaling_NaN() { + return standalone::c10::BFloat16( + 0x7F80, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 denorm_min() { + return standalone::c10::BFloat16( + 0x0001, standalone::c10::BFloat16::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/BFloat16-math.h b/backends/aoti/slim/c10/util/BFloat16-math.h new file mode 100644 index 00000000000..f036f309e26 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16-math.h @@ -0,0 +1,332 @@ +#pragma once + +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace standalone::c10 { +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || + std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; +} // namespace standalone::c10 + +namespace std { + +#if !defined(FBCODE_CAFFE2) && !defined(STANDALONE_NODEPRECATED) +using standalone::c10::is_reduced_floating_point; +using standalone::c10::is_reduced_floating_point_v; +#endif // !defined(FBCODE_CAFFE2) && !defined(STANDALONE_NODEPRECATED) + +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std:: + enable_if_t, int> = 0> +STANDALONE_HOST_DEVICE inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/BFloat16.h b/backends/aoti/slim/c10/util/BFloat16.h new file mode 100644 index 00000000000..ed6d07f53d0 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16.h @@ -0,0 +1,123 @@ +#pragma once + +// Defines the bloat16 type (brain floating-point). This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. + +#include +#include +#include +#include +#include +#include + +#if defined(__CUDACC__) && !defined(USE_ROCM) +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace standalone::c10 { + +namespace detail { +inline STANDALONE_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + +#if defined(USE_ROCM) && defined(__HIPCC__) + float* tempRes; + + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; +#else + std::memcpy(&res, &tmp, sizeof(tmp)); +#endif + + return res; +} + +inline STANDALONE_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; + +#if defined(USE_ROCM) && defined(__HIPCC__) + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; +#else + std::memcpy(&res, &src, sizeof(res)); +#endif + + return res >> 16; +} + +inline STANDALONE_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +#if defined(USE_ROCM) && defined(__HIPCC__) + if (src != src) { +#elif defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + union { + uint32_t U32; // NOLINT(facebook-hte-BadMemberName) + float F32; // NOLINT(facebook-hte-BadMemberName) + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} +} // namespace detail + +struct alignas(2) BFloat16 { + uint16_t x; + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) && defined(__HIPCC__) + STANDALONE_HOST_DEVICE BFloat16() = default; +#else + BFloat16() = default; +#endif + + struct from_bits_t {}; + static constexpr STANDALONE_HOST_DEVICE from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr STANDALONE_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits) {} + /* implicit */ inline STANDALONE_HOST_DEVICE BFloat16(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) && !defined(USE_ROCM) + inline STANDALONE_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline STANDALONE_HOST_DEVICE operator __nv_bfloat16() const; +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + inline STANDALONE_HOST_DEVICE BFloat16( + const sycl::ext::oneapi::bfloat16& value); + explicit inline STANDALONE_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() + const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Exception.h b/backends/aoti/slim/c10/util/Exception.h new file mode 100644 index 00000000000..6ab2bd8aae6 --- /dev/null +++ b/backends/aoti/slim/c10/util/Exception.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include +#include + +// In the standalone version, STANDALONE_CHECK throws std::runtime_error +// instead of standalone::c10::Error. +namespace standalone::c10::detail { +template +std::string torchCheckMsgImpl(const char* /*msg*/, const Args&... args) { + // This is similar to the one in c10/util/Exception.h, but does + // not depend on the more complex c10::str() function. + // ostringstream may support less data types than c10::str(), + // but should be sufficient in the standalone world. + std::ostringstream oss; + ((oss << args), ...); + return oss.str(); +} +inline const char* torchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline const char* torchCheckMsgImpl(const char* /*msg*/, const char* args) { + return args; +} +} // namespace standalone::c10::detail + +#define STANDALONE_CHECK_MSG(cond, type, ...) \ + (::standalone::c10::detail::torchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#define STANDALONE_CHECK(cond, ...) \ + if (STANDALONE_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STANDALONE_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + ##__VA_ARGS__)); \ + } +#define STANDALONE_INTERNAL_ASSERT(cond, ...) \ + if (STANDALONE_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STANDALONE_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + #cond, \ + " INTERNAL ASSERT FAILED: ", \ + ##__VA_ARGS__)); \ + } + +#define WARNING_MESSAGE_STRING(...) \ + ::standalone::c10::detail::torchCheckMsgImpl(__VA_ARGS__) + +#ifdef DISABLE_WARN +#define _STANDALONE_WARN_WITH(...) ((void)0); +#else +#define _STANDALONE_WARN_WITH(...) \ + std::cerr << __func__ << ", " << __FILE__ << ":" << __LINE__ << ", " \ + << WARNING_MESSAGE_STRING(__VA_ARGS__) << std::endl; +#endif + +#define STANDALONE_WARN(...) _STANDALONE_WARN_WITH(__VA_ARGS__); + +#ifdef NDEBUG +// Optimized version - generates no code. +#define STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + while (false) \ + STANDALONE_EXPAND_MSVC_WORKAROUND(STANDALONE_INTERNAL_ASSERT(__VA_ARGS__)) +#else +#define STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + STANDALONE_EXPAND_MSVC_WORKAROUND(STANDALONE_INTERNAL_ASSERT(__VA_ARGS__)) +#endif diff --git a/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h b/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h new file mode 100644 index 00000000000..600e281b583 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h @@ -0,0 +1,28 @@ +#pragma once +#include + +#include + +/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed +/// into one byte). This is the FP4 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.3.3) +/// +/// Given two high precision values val0 and val1, here is the +/// binary configuration of their packed representation, from MSB to LSB: +/// +/// original value | val1 : val0 +/// ======================================== +/// bit index (MSB==7, LSB==0) | 7654 : 3210 +/// sign/exponent/mantissa | seem : seem +/// + +namespace standalone::c10 { + +struct alignas(1) Float4_e2m1fn_x2 { + uint8_t val_; + Float4_e2m1fn_x2() = default; + STANDALONE_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h b/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h new file mode 100644 index 00000000000..cc31b82e699 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h @@ -0,0 +1,297 @@ +#pragma once + +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) + : x(detail::fp8e4m3fn_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn::operator float() const { + return detail::fp8e4m3fn_to_fp32_value(x); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e4m3fn::isnan() const { + return (x & 0b01111111) == 0b01111111; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/( + const Float8_e4m3fn& a, + const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator+=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator-=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator*=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator/=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e4m3fn& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e4m3fn& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e4m3fn& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e4m3fn& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator+(Float8_e4m3fn a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator-(Float8_e4m3fn a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator*(Float8_e4m3fn a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator/(Float8_e4m3fn a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator+(int64_t a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator-(int64_t a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator*(int64_t a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator/(int64_t a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e4m3fn to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr standalone::c10::Float8_e4m3fn min() { + return standalone::c10::Float8_e4m3fn( + 0x08, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn lowest() { + return standalone::c10::Float8_e4m3fn( + 0xFE, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn max() { + return standalone::c10::Float8_e4m3fn( + 0x7E, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn epsilon() { + return standalone::c10::Float8_e4m3fn( + 0x20, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn round_error() { + return standalone::c10::Float8_e4m3fn( + 0x30, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn quiet_NaN() { + return standalone::c10::Float8_e4m3fn( + 0x7F, standalone::c10::Float8_e4m3fn::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fn denorm_min() { + return standalone::c10::Float8_e4m3fn( + 0x01, standalone::c10::Float8_e4m3fn::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fn.h b/backends/aoti/slim/c10/util/Float8_e4m3fn.h new file mode 100644 index 00000000000..320a677cbbb --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fn.h @@ -0,0 +1,238 @@ +#pragma once + +/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// bias = 7 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/standalone/c10/util/Half.h + +#include +#include + +#if defined(__cplusplus) +#include +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline STANDALONE_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E4M3FN number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)input << 24; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(nonsign); +#elif defined(__SYCL_DEVICE_ONLY__) + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#elif defined(_MSC_VER) && !defined(__clang__) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#endif + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + /* + * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, + * the addition overflows it into bit 31, and the subsequent shift turns the + * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number + * is Nan, 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 4 so the exponent (4 bits originally) + * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0x07 + * for fp8e4m3fn number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x78, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { + /* + * Binary representation of 480.0f, which is the first value + * not representable in fp8e4m3fn range: + * 0 1111 111 - fp8e4m3fn + * 0 10000111 11100000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fn normal range + * into denorm representation + * magic number: ((127 - 7) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(141) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = 0x7f; + } else { + if (f_bits < (UINT32_C(121) << 23)) { + // Input number is smaller than 2^(-6), which is the smallest + // fp8e4m3fn normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fn { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fn() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e4m3fn(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h b/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h new file mode 100644 index 00000000000..55a6ce73972 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h @@ -0,0 +1,312 @@ +#pragma once + +#include +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) + : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<4, 3>(x); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { + return x == 0b10000000; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(Float8_e4m3fnuz a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(Float8_e4m3fnuz a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(Float8_e4m3fnuz a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(Float8_e4m3fnuz a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(int a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(int a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(int a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(int a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(Float8_e4m3fnuz a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(Float8_e4m3fnuz a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(Float8_e4m3fnuz a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(Float8_e4m3fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e4m3fnuz to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr standalone::c10::Float8_e4m3fnuz min() { + return standalone::c10::Float8_e4m3fnuz( + 0x08, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz lowest() { + return standalone::c10::Float8_e4m3fnuz( + 0xFF, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz max() { + return standalone::c10::Float8_e4m3fnuz( + 0x7F, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz epsilon() { + return standalone::c10::Float8_e4m3fnuz( + 0x28, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz round_error() { + return standalone::c10::Float8_e4m3fnuz( + 0x38, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz infinity() { + // NaN (no infinities) + return standalone::c10::Float8_e4m3fnuz( + 0x80, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz quiet_NaN() { + return standalone::c10::Float8_e4m3fnuz( + 0x80, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz denorm_min() { + return standalone::c10::Float8_e4m3fnuz( + 0x01, standalone::c10::Float8_e4m3fnuz::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h b/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h new file mode 100644 index 00000000000..ff3c050f018 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h @@ -0,0 +1,138 @@ +#pragma once + +/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as Float8_e4m3fn: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// The key differences versus Float8_e4m3fn are: +/// bias = 8 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { + /* + * Binary representation of 256.0f, which is the first value not representable + * (i.e. the first value which would overflow in to the sign bit, resulting in + * a NaN) in fp8e4m3fnuz range: + * 1 0000 000 - fp8e4m3fnuz + * 0 10000111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range + * into denorm representation + * magic number: ((127 - 8) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s. + return 0x80; + } + + if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { + // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fnuz() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fnuz& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e5m2-inl.h b/backends/aoti/slim/c10/util/Float8_e5m2-inl.h new file mode 100644 index 00000000000..c8e90a8aa0d --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2-inl.h @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#define EXP_WIDTH_FP8 5 +#define MAN_WIDTH_FP8 2 +#define EXP_BIAS_FP8 15 + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) + : x(detail::fp8e5m2_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e5m2::operator float() const { + return detail::fp8e5m2_to_fp32_value(x); +} + +/// Special values helpers + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2::isnan() const { + return (x & 0b01111111) > 0b01111100; +} + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2::isinf() const { + return (x & 0b01111111) == 0b01111100; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator+(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator-(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator*(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/( + const Float8_e5m2& a, + const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator+=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator-=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator*=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator/=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e5m2 a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e5m2& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e5m2& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e5m2& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e5m2& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e5m2 a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e5m2 to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::Float8_e5m2 min() { + return standalone::c10::Float8_e5m2( + 0x4, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 max() { + return standalone::c10::Float8_e5m2( + 0x7B, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 lowest() { + return standalone::c10::Float8_e5m2( + 0xFB, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 epsilon() { + return standalone::c10::Float8_e5m2( + 0x34, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 round_error() { + return standalone::c10::Float8_e5m2( + 0x38, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 infinity() { + return standalone::c10::Float8_e5m2( + 0x7C, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 quiet_NaN() { + return standalone::c10::Float8_e5m2( + 0x7F, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 denorm_min() { + return standalone::c10::Float8_e5m2( + 0x01, standalone::c10::Float8_e5m2::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e5m2.h b/backends/aoti/slim/c10/util/Float8_e5m2.h new file mode 100644 index 00000000000..88d1aab0525 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2.h @@ -0,0 +1,147 @@ +#pragma once + +/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// bias = 15 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/standalone/c10/util/Half.h + +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline STANDALONE_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E5M2 number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 26-30 24-25 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + uint16_t half_representation = input; + half_representation <<= 8; + return fp16_ieee_to_fp32_value(half_representation); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { + /* + * Binary representation of fp32 infinity + * 0 11111111 00000000000000000000000 + */ + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + + /* + * Binary representation of 65536.0f, which is the first value + * not representable in fp8e5m2 range: + * 0 11111 00 - fp8e5m2 + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2 normal range + * into denorm representation + * magic number: ((127 - 15) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); + } else { + if (f_bits < (UINT32_C(113) << 23)) { + // Input number is smaller than 2^(-14), which is the smallest + // fp8e5m2 normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint32_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2 { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e5m2(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; + inline STANDALONE_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h b/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h new file mode 100644 index 00000000000..d2ccac329af --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h @@ -0,0 +1,318 @@ +#pragma once + +#include +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) + : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<5, 2>(x); +} + +/// Special values helpers + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { + return x == 0b10000000; +} + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(Float8_e5m2fnuz a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(Float8_e5m2fnuz a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(Float8_e5m2fnuz a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(Float8_e5m2fnuz a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(int a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(int a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(int a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(int a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(Float8_e5m2fnuz a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(Float8_e5m2fnuz a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(Float8_e5m2fnuz a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(Float8_e5m2fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e5m2fnuz to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::Float8_e5m2fnuz min() { + return standalone::c10::Float8_e5m2fnuz( + 0x04, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz max() { + return standalone::c10::Float8_e5m2fnuz( + 0x7F, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz lowest() { + return standalone::c10::Float8_e5m2fnuz( + 0xFF, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz epsilon() { + return standalone::c10::Float8_e5m2fnuz( + 0x34, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz round_error() { + return standalone::c10::Float8_e5m2fnuz( + 0x38, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz infinity() { + return standalone::c10::Float8_e5m2fnuz( + 0x80, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + // TODO(future): we are mapping neg_zero to both inf and NaN, this is + // surprising and we should figure out what to do about it. + static constexpr standalone::c10::Float8_e5m2fnuz quiet_NaN() { + return standalone::c10::Float8_e5m2fnuz( + 0x80, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz denorm_min() { + return standalone::c10::Float8_e5m2fnuz( + 0x01, standalone::c10::Float8_e5m2fnuz::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h b/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h new file mode 100644 index 00000000000..c16e5613202 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h @@ -0,0 +1,138 @@ +#pragma once + +/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as e5m2: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// The key differences that e5m2fnuz brings are: +/// bias = 16 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { + /* + * Binary representation of 65536.0f, which is the first value not + * representable (i.e. the first value which would overflow in to the sign + * bit, resulting in a NaN) in fp8e4m3fnuz range: + * 1 00000 00 - fp8e5m2fnuz + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range + * into denormalized representation. + * magic number: ((127 - 16) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s + return 0x80; + } + + if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { + // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2fnuz() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; + inline STANDALONE_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2fnuz& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h b/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h new file mode 100644 index 00000000000..f510ca551b8 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include + +// TODO(#146647): Can we remove the below warning? +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) + : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e8m0fnu::operator float() const { + // TODO(#146647): maybe rewrite without control flow + + // if exponent is zero, need to special case to return 2^-127 instead of zero + if (x == 0) { + return standalone::c10::detail::fp32_from_bits(0x00400000); + } + + // if exponent is NaN, need to special case to return properly encoded NaN + if (isnan()) { + return standalone::c10::detail::fp32_from_bits(0x7f800001); + } + + // leave sign at 0, set the exponent bits, leave stored mantissa at 0 + uint32_t res = x << 23; + + return standalone::c10::detail::fp32_from_bits(res); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { + return x == 0b11111111; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e8m0fnu to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = false; + static constexpr auto has_denorm_loss = false; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 1; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 1; // just a 2! + static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -38; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr standalone::c10::Float8_e8m0fnu min() { + // 2^-127 + return standalone::c10::Float8_e8m0fnu( + 0b00000000, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu lowest() { + // 2^-127 + return standalone::c10::Float8_e8m0fnu( + 0b00000000, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu max() { + // 254 biased, which is 127 unbiased, so 2^127 + return standalone::c10::Float8_e8m0fnu( + 0b11111110, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu epsilon() { + // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this + // is "the difference between 1.0 and the next representable value of the + // given floating-point type". The next representable value is 2.0, so the + // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. + return standalone::c10::Float8_e8m0fnu( + 0b01111111, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu round_error() { + // 0.5 in float, which is 2^-1, and -1 + 127 = 126 + return standalone::c10::Float8_e8m0fnu( + 0b01111110, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu quiet_NaN() { + return standalone::c10::Float8_e8m0fnu( + 0b11111111, standalone::c10::Float8_e8m0fnu::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e8m0fnu.h b/backends/aoti/slim/c10/util/Float8_e8m0fnu.h new file mode 100644 index 00000000000..2e2e46d627a --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e8m0fnu.h @@ -0,0 +1,119 @@ +#pragma once + +/// Defines the Float8_e8m0fnu type (8-bit floating-point) including +/// conversions to standard C types +/// Binary configuration : +/// eeeeeeee +/// no sign bits +/// 8 exponent bits +/// no mantissa bits +/// +/// This is the E8M0 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.4.1) + +#include +#include +#include + +// TODO(#146647): do we need to special case OPENCL? +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { + // TODO(#146647): maybe rewrite without control flow + + uint32_t f_bits = standalone::c10::detail::fp32_to_bits(f); + + // extract the exponent + uint32_t exponent = (f_bits >> 23) & 0b11111111; + + // special case float32 NaN and +-inf to map to e8m0 nan + if (exponent == 0b11111111) { + return exponent; + } + + // next, we use guard, round, sticky bits and the LSB to implement round to + // nearest, with ties to even + + // guard bit - bit 23, or 22 zero-indexed + uint8_t g = (f_bits & 0x400000) > 0; + // round bit - bit 22, or 21 zero-indexed + uint8_t r = (f_bits & 0x200000) > 0; + // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed + uint8_t s = (f_bits & 0x1FFFFF) > 0; + // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the + // original float32 is denormal, and to 1 if the original float32 is normal. + uint8_t lsb = exponent > 0; + + // implement the RNE logic + bool round_up = false; + + // if g == 0, round down (no-op) + if (g == 1) { + if ((r == 1) || (s == 1)) { + // round up + round_up = true; + } else { + if (lsb == 1) { + // round up + round_up = true; + } + // if lsb == 0, round down (no-op) + } + } + + if (round_up) { + // adjust exponent + // note that if exponent was 255 we would have already returned earlier, so + // we know we can add one safely without running out of bounds + exponent++; + } + + return exponent; +} + +} // namespace detail + +struct alignas(1) Float8_e8m0fnu { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e8m0fnu() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e8m0fnu(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e8m0fnu& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h b/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h new file mode 100644 index 00000000000..00bfa8cd8fc --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include + +#if defined(SYCL_LANGUAGE_VERSION) +#include +#endif + +namespace standalone::c10::detail { + +/* + * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ + * format, in bit representation, to a 32-bit floating-point number. + */ +template +inline STANDALONE_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { + static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); + constexpr uint32_t weo = 8; + constexpr uint32_t wmo = 23; + + if (x == 0) { + return 0; + } + + if (x == 0x80) { + constexpr uint32_t ifNaN = 0x7F800001; + return fp32_from_bits(ifNaN); + } + + uint32_t mantissa = x & ((1 << wm) - 1); + uint32_t exponent = (x & 0x7F) >> wm; + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(mantissa); +#elif defined(__SYCL_DEVICE_ONLY__) + uint32_t renorm_shift = sycl::clz(mantissa); +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(mantissa); +#endif + uint32_t sh = 1 + renorm_shift - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + + const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + uint32_t sign = x >> 7; + uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; + return fp32_from_bits(retval); +} + +} // namespace standalone::c10::detail diff --git a/backends/aoti/slim/c10/util/Half-inl.h b/backends/aoti/slim/c10/util/Half-inl.h new file mode 100644 index 00000000000..05fa6349f81 --- /dev/null +++ b/backends/aoti/slim/c10/util/Half-inl.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include + +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +// TODO: add contents in ATen/cpu/vec/vec_half.h +// #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ +// !defined(__APPLE__) +// #include +// #endif + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +#if defined(__aarch64__) && !defined(__CUDACC__) +/// Constructors +inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} +inline Half::operator float16_t() const { + return detail::fp16_from_bits(x); +} +#else + +inline STANDALONE_HOST_DEVICE Half::Half(float value) + : +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x(__half_as_short(__float2half(value))) +#elif defined(__SYCL_DEVICE_ONLY__) + x(standalone::c10::bit_cast(sycl::half(value))) +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + x(at::vec::float2half_scalar(value)) +#else + x(detail::fp16_ieee_from_fp32_value(value)) +#endif +{ +} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) + return float(standalone::c10::bit_cast(x)); +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + return at::vec::half2float_scalar(x); +#elif defined(__aarch64__) && !defined(__CUDACC__) + return detail::native_fp16_to_fp32_value(x); +#else + return detail::fp16_ieee_to_fp32_value(x); +#endif +} + +#endif /* !defined(__aarch64__) || defined(__CUDACC__) \ + */ + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline STANDALONE_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast(&x); +} +#endif + +#ifdef SYCL_LANGUAGE_VERSION +inline STANDALONE_HOST_DEVICE Half::Half(const sycl::half& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE Half::operator sycl::half() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ + (defined(__clang__) && defined(__CUDA__)) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast(ptr)); +} +#endif + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator-(const Half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); +#elif defined(__SYCL_DEVICE_ONLY__) + return -standalone::c10::bit_cast(a); +#else + return -static_cast(a); +#endif +} + +inline STANDALONE_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Half a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Half a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Half a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Half a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Half a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Half a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Half operator+(Half a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator-(Half a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator*(Half a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator/(Half a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator+(int a, Half b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Half operator-(int a, Half b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Half operator*(int a, Half b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Half operator/(int a, Half b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Half operator+(Half a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator-(Half a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator*(Half a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator/(Half a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator+(int64_t a, Half b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Half operator-(int64_t a, Half b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Half operator*(int64_t a, Half b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Half operator/(int64_t a, Half b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Half to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr standalone::c10::Half min() { + return standalone::c10::Half(0x0400, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half lowest() { + return standalone::c10::Half(0xFBFF, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half max() { + return standalone::c10::Half(0x7BFF, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half epsilon() { + return standalone::c10::Half(0x1400, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half round_error() { + return standalone::c10::Half(0x3800, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half infinity() { + return standalone::c10::Half(0x7C00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half quiet_NaN() { + return standalone::c10::Half(0x7E00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half signaling_NaN() { + return standalone::c10::Half(0x7D00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half denorm_min() { + return standalone::c10::Half(0x0001, standalone::c10::Half::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Half.h b/backends/aoti/slim/c10/util/Half.h new file mode 100644 index 00000000000..86f8d8683e0 --- /dev/null +++ b/backends/aoti/slim/c10/util/Half.h @@ -0,0 +1,424 @@ +#pragma once + +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. + +#include +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if defined(__aarch64__) && !defined(__CUDACC__) +#include +#endif + +#if defined(__GNUC__) || defined(__clang__) +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \ + defined(_M_IX86) +#if defined(__F16C__) && \ + !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \ + defined(__HIP_DEVICE_COMPILE__)) +#define STANDALONE_X86_F16 1 +#include // import conversion ops from f16cintrin.h +#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__) + // || defined(__HIP_DEVICE_COMPILE__)) +#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 +#endif // __GNUC__ || __clang__ + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +STANDALONE_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) { +#ifdef STANDALONE_X86_F16 + return _cvtsh_ss(h); +#else + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. + */ + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val = 0; +#if defined(_MSC_VER) && defined(__clang__) + __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#else + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#endif + + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructed single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. + */ + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +#endif // STANDALONE_X86_F16 +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { +#ifdef STANDALONE_X86_F16 + return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT); +#else + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val = 0, scale_to_zero_val = 0; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy( + &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); +#endif // STANDALONE_X86_F16 +} + +#ifdef STANDALONE_X86_F16 +#undef STANDALONE_X86_F16 +#endif // STANDALONE_X86_F16 + +#if defined(__aarch64__) && !defined(__CUDACC__) +inline float16_t fp16_from_bits(uint16_t h) { + return standalone::c10::bit_cast(h); +} + +inline uint16_t fp16_to_bits(float16_t f) { + return standalone::c10::bit_cast(f); +} + +// According to https://godbolt.org/z/frExdbsWG it would translate to single +// fcvt s0, h0 +inline float native_fp16_to_fp32_value(uint16_t h) { + return static_cast(fp16_from_bits(h)); +} + +inline uint16_t native_fp16_from_fp32_value(float f) { + return fp16_to_bits(static_cast(f)); +} +#endif + +} // namespace detail + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + STANDALONE_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr STANDALONE_HOST_DEVICE Half(unsigned short bits, from_bits_t) + : x(bits) {} +#if defined(__aarch64__) && !defined(__CUDACC__) + inline Half(float16_t value); + inline operator float16_t() const; +#else + inline STANDALONE_HOST_DEVICE Half(float value); + inline STANDALONE_HOST_DEVICE operator float() const; +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline STANDALONE_HOST_DEVICE Half(const __half& value); + inline STANDALONE_HOST_DEVICE operator __half() const; +#endif +#ifdef SYCL_LANGUAGE_VERSION + inline STANDALONE_HOST_DEVICE Half(const sycl::half& value); + inline STANDALONE_HOST_DEVICE operator sycl::half() const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const Half& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/StringUtil.h b/backends/aoti/slim/c10/util/StringUtil.h new file mode 100644 index 00000000000..ff7c591e734 --- /dev/null +++ b/backends/aoti/slim/c10/util/StringUtil.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { +template +inline std::string Join(const std::string& delimiter, const Container& v) { + std::stringstream s; + int cnt = static_cast(v.size()) - 1; + for (auto i = v.begin(); i != v.end(); ++i, --cnt) { + s << (*i) << (cnt ? delimiter : ""); + } + return std::move(s).str(); +} +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/TypeCast.h b/backends/aoti/slim/c10/util/TypeCast.h new file mode 100644 index 00000000000..cfaaaebec95 --- /dev/null +++ b/backends/aoti/slim/c10/util/TypeCast.h @@ -0,0 +1,236 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +template +struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); +}; + +template +struct maybe_real { + STANDALONE_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_real { + STANDALONE_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + return src.real(); + } +}; + +template +struct maybe_bool { + STANDALONE_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_bool { + STANDALONE_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + // Don't use bool operator so as to to also compile for ComplexHalf. + return src.real() || src.imag(); + } +}; + +// Note: deliberately ignores undefined behavior, consistent with NumPy. +// PyTorch's type conversions can cause a variety of undefined behavior, +// including float to integral overflow and signed to unsigned integer overflow. +// Some of this undefined behavior is addressed below. +template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + auto r = maybe_real::apply(src); + return static_cast(r); + } +}; + +// Partial template specialization for casting to bool. +// Need to handle complex types separately, as we don't +// simply want to cast the real part to bool. +template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE static inline bool apply(src_t src) { + constexpr bool complex = needs_real::value; + return static_cast(maybe_bool::apply(src)); + } +}; + +// Partial template instantiation for casting to uint8. +// Note: Converting from negative float values to unsigned integer types is +// undefined behavior in C++, and current CPU and GPU compilers exhibit +// divergent behavior. Casting from negative float values to signed +// integer types and then to unsigned integer types is not undefined, +// however, so this cast improves the consistency of type conversions +// to uint8 across compilers. +// Further note: Type conversions across compilers still have other undefined +// and divergent behavior. +template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::BFloat16> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::BFloat16 src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e5m2> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e5m2 src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e5m2fnuz> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e5m2fnuz src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e4m3fn> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e4m3fn src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e4m3fnuz> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e4m3fnuz src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +// TODO(#146647): Can we make all these template specialization happen +// based off our apply macros? +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e8m0fnu> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e8m0fnu src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Half> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Half src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::complex> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::complex src) { + return static_cast>( + static_cast>(src)); + } +}; + +template +STANDALONE_HOST_DEVICE To convert(From f) { + return static_cast_with_inter_type::apply(f); +} + +// Define separately to avoid being inlined and prevent code-size bloat +[[noreturn]] inline void report_overflow(const char* name) { + std::ostringstream oss; + oss << "value cannot be converted to type " << name << " without overflow"; + throw std::runtime_error(oss.str()); // rather than domain_error (issue 33562) +} + +template +To checked_convert(From f, const char* name) { + // Converting to bool can't overflow so we exclude this case from checking. + if (!std::is_same_v && + overflows(f, /* strict_unsigned */ !std::is_signed_v)) { + report_overflow(name); + } + return convert(f); +} + +} // namespace standalone::c10 + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +// Trigger tests for D25440771. TODO: Remove this line any time you want. diff --git a/backends/aoti/slim/c10/util/TypeSafeSignMath.h b/backends/aoti/slim/c10/util/TypeSafeSignMath.h new file mode 100644 index 00000000000..276b1cee7d0 --- /dev/null +++ b/backends/aoti/slim/c10/util/TypeSafeSignMath.h @@ -0,0 +1,141 @@ +#pragma once + +#include + +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wstring-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Returns false since we cannot have x < 0 if x is unsigned. +template +inline constexpr bool is_negative( + const T& /*x*/, + std::true_type /*is_unsigned*/) { + return false; +} + +/// Returns true if a signed variable x < 0 +template +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { + return x < T(0); +} + +/// Returns true if x < 0 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, standalone::c10::Half does not :-( +template +inline constexpr bool is_negative(const T& x) { + return is_negative(x, std::is_unsigned()); +} + +/// Returns the sign of an unsigned variable x as 0, 1 +template +inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { + return T(0) < x; +} + +/// Returns the sign of a signed variable x as -1, 0, 1 +template +inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { + return (T(0) < x) - (x < T(0)); +} + +/// Returns the sign of x as -1, 0, 1 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, standalone::c10::Half does not :-( +template +inline constexpr int signum(const T& x) { + return signum(x, std::is_unsigned()); +} + +/// Returns true if a and b are not both negative +template +inline constexpr bool signs_differ(const T& a, const U& b) { + return is_negative(a) != is_negative(b); +} + +// Suppress sign compare warning when compiling with GCC +// as later does not account for short-circuit rule before +// raising the warning, see https://godbolt.org/z/Tr3Msnz99 +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif + +/// Returns true if x is greater than the greatest value of the type Limit +template +inline constexpr bool greater_than_max(const T& x) { + constexpr bool can_overflow = + std::numeric_limits::digits > std::numeric_limits::digits; + return can_overflow && x > std::numeric_limits::max(); +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +/// Returns true if x < lowest(Limit). Standard comparison +template +inline constexpr bool less_than_lowest( + const T& x, + std::false_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < std::numeric_limits::lowest(); +} + +/// Returns false since all the limit is signed and therefore includes +/// negative values but x cannot be negative because it is unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::false_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x < 0, where 0 is constructed from T. +/// Limit is not signed, so its lower value is zero +template +inline constexpr bool less_than_lowest( + const T& x, + std::true_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < T(0); +} + +/// Returns false sign both types are unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::true_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x is less than the lowest value of type T +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, standalone::c10::Half does not : +template +inline constexpr bool less_than_lowest(const T& x) { + return less_than_lowest( + x, std::is_unsigned(), std::is_unsigned()); +} + +} // namespace standalone::c10 + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/accumulate.h b/backends/aoti/slim/c10/util/accumulate.h new file mode 100644 index 00000000000..4972dd9826a --- /dev/null +++ b/backends/aoti/slim/c10/util/accumulate.h @@ -0,0 +1,125 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { + +/// Sum of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t sum_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), container.end(), static_cast(0)); +} + +/// Sum of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t sum_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate(begin, end, static_cast(0)); +} + +/// Product of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t multiply_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), + container.end(), + static_cast(1), + std::multiplies<>()); +} + +/// Product of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t multiply_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + begin, end, static_cast(1), std::multiplies<>()); +} + +/// Return product of all dimensions starting from k +/// Returns 1 if k>=dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_from_dim(const int k, const C& dims) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0); + + if (k > static_cast(dims.size())) { + return 1; + } else { + auto cbegin = dims.cbegin(); + std::advance(cbegin, k); + return multiply_integers(cbegin, dims.cend()); + } +} + +/// Product of all dims up to k (not including dims[k]) +/// Throws an error if k>dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_to_dim(const int k, const C& dims) { + STANDALONE_INTERNAL_ASSERT(0 <= k); + STANDALONE_INTERNAL_ASSERT((unsigned)k <= dims.size()); + + auto cend = dims.cbegin(); + std::advance(cend, k); + return multiply_integers(dims.cbegin(), cend); +} + +/// Product of all dims between k and l (including dims[k] and excluding +/// dims[l]) k and l may be supplied in either order +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_between_dim(int k, int l, const C& dims) { + STANDALONE_INTERNAL_ASSERT(0 <= k); + STANDALONE_INTERNAL_ASSERT(0 <= l); + + if (k > l) { + std::swap(k, l); + } + + STANDALONE_INTERNAL_ASSERT((unsigned)l < dims.size()); + + auto cbegin = dims.cbegin(); + auto cend = dims.cbegin(); + std::advance(cbegin, k); + std::advance(cend, l); + return multiply_integers(cbegin, cend); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/bit_cast.h b/backends/aoti/slim/c10/util/bit_cast.h new file mode 100644 index 00000000000..765ec641486 --- /dev/null +++ b/backends/aoti/slim/c10/util/bit_cast.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#if __has_include() && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L) +#include +#define STANDALONE_HAVE_STD_BIT_CAST 1 +#else +#define STANDALONE_HAVE_STD_BIT_CAST 0 +#endif // __has_include() && (__cplusplus >= 202002L || + // (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)) + +namespace standalone::c10 { + +#if STANDALONE_HAVE_STD_BIT_CAST +using std::bit_cast; +#else +// Implementations of std::bit_cast() from C++ 20. +// +// This is a less sketchy version of reinterpret_cast. +// +// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more +// information as well as the source of our implementations. +template +std::enable_if_t< + sizeof(To) == sizeof(From) && std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> +// constexpr support needs compiler magic +bit_cast(const From& src) noexcept { + static_assert( + std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} +#endif // STANDALONE_HAVE_STD_BIT_CAST +#undef STANDALONE_HAVE_STD_BIT_CAST + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/bits.h b/backends/aoti/slim/c10/util/bits.h new file mode 100644 index 00000000000..2d365463a01 --- /dev/null +++ b/backends/aoti/slim/c10/util/bits.h @@ -0,0 +1,61 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits1x8 { + using underlying = uint8_t; + uint8_t val_; + bits1x8() = default; + STANDALONE_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {} +}; + +/** + * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits2x4 { + using underlying = uint8_t; + uint8_t val_; + bits2x4() = default; + STANDALONE_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {} +}; + +/** + * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits4x2 { + using underlying = uint8_t; + uint8_t val_; + bits4x2() = default; + STANDALONE_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {} +}; + +/** + * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any + * semantics defined. + */ +struct alignas(1) bits8 { + uint8_t val_; + bits8() = default; + STANDALONE_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} +}; + +/** + * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any + * semantics defined. + */ +struct alignas(2) bits16 { + uint16_t val_; + bits16() = default; + STANDALONE_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/complex.h b/backends/aoti/slim/c10/util/complex.h new file mode 100644 index 00000000000..988e446b3e4 --- /dev/null +++ b/backends/aoti/slim/c10/util/complex.h @@ -0,0 +1,690 @@ +#pragma once + +#include + +#include +#include + +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#endif + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wfloat-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") +#endif + +namespace standalone::c10 { + +// standalone::c10::complex is an implementation of complex numbers that aims +// to work on all devices supported by PyTorch +// +// Most of the APIs duplicates std::complex +// Reference: https://en.cppreference.com/w/cpp/numeric/complex +// +// [NOTE: Complex Operator Unification] +// Operators currently use a mix of std::complex, thrust::complex, and +// standalone::c10::complex internally. The end state is that all operators +// will use standalone::c10::complex internally. Until then, there may be +// some hacks to support all variants. +// +// +// [Note on Constructors] +// +// The APIs of constructors are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/complex +// +// Since C++14, all constructors are constexpr in std::complex +// +// There are three types of constructors: +// - initializing from real and imag: +// `constexpr complex( const T& re = T(), const T& im = T() );` +// - implicitly-declared copy constructor +// - converting constructors +// +// Converting constructors: +// - std::complex defines converting constructor between float/double/long +// double, +// while we define converting constructor between float/double. +// - For these converting constructors, upcasting is implicit, downcasting is +// explicit. +// - We also define explicit casting from std::complex/thrust::complex +// - Note that the conversion from thrust is not constexpr, because +// thrust does not define them as constexpr ???? +// +// +// [Operator =] +// +// The APIs of operator = are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D +// +// Since C++20, all operator= are constexpr. Although we are not building with +// C++20, we also obey this behavior. +// +// There are three types of assign operator: +// - Assign a real value from the same scalar type +// - In std, this is templated as complex& operator=(const T& x) +// with specialization `complex& operator=(T x)` for float/double/long +// double Since we only support float and double, on will use `complex& +// operator=(T x)` +// - Copy assignment operator and converting assignment operator +// - There is no specialization of converting assignment operators, which type +// is +// convertible is solely dependent on whether the scalar type is convertible +// +// In addition to the standard assignment, we also provide assignment operators +// with std and thrust +// +// +// [Casting operators] +// +// std::complex does not have casting operators. We define casting operators +// casting to std::complex and thrust::complex +// +// +// [Operator ""] +// +// std::complex has custom literals `i`, `if` and `il` defined in namespace +// `std::literals::complex_literals`. We define our own custom literals in the +// namespace `standalone::c10::complex_literals`. Our custom literals does not +// follow the same behavior as in std::complex, instead, we define _if, _id to +// construct float/double complex literals. +// +// +// [real() and imag()] +// +// In C++20, there are two overload of these functions, one it to return the +// real/imag, another is to set real/imag, they are both constexpr. We follow +// this design. +// +// +// [Operator +=,-=,*=,/=] +// +// Since C++20, these operators become constexpr. In our implementation, they +// are also constexpr. +// +// There are two types of such operators: operating with a real number, or +// operating with another complex number. For the operating with a real number, +// the generic template form has argument type `const T &`, while the overload +// for float/double/long double has `T`. We will follow the same type as +// float/double/long double in std. +// +// [Unary operator +-] +// +// Since C++20, they are constexpr. We also make them expr +// +// [Binary operators +-*/] +// +// Each operator has three versions (taking + as example): +// - complex + complex +// - complex + real +// - real + complex +// +// [Operator ==, !=] +// +// Each operator has three versions (taking == as example): +// - complex == complex +// - complex == real +// - real == complex +// +// Some of them are removed on C++20, but we decide to keep them +// +// [Operator <<, >>] +// +// These are implemented by casting to std::complex +// +// +// +// TODO(@zasdfgbnm): standalone::c10::complex is not +// currently supported, because: +// - lots of members and functions of standalone::c10::Half are not constexpr +// - thrust::complex only support float and double + +template +struct alignas(sizeof(T) * 2) complex { + using value_type = T; + + T real_ = T(0); + T imag_ = T(0); + + constexpr complex() = default; + STANDALONE_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) + : real_(re), imag_(im) {} + template + explicit constexpr complex(const std::complex& other) + : complex(other.real(), other.imag()) {} +#if defined(__CUDACC__) || defined(__HIPCC__) + template + explicit STANDALONE_HOST_DEVICE complex(const thrust::complex& other) + : real_(other.real()), imag_(other.imag()) {} +// NOTE can not be implemented as follow due to ROCm bug: +// explicit STANDALONE_HOST_DEVICE complex(const thrust::complex &other): +// complex(other.real(), other.imag()) {} +#endif + + // Use SFINAE to specialize casting constructor for + // standalone::c10::complex and standalone::c10::complex + template + STANDALONE_HOST_DEVICE explicit constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + template + STANDALONE_HOST_DEVICE constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + + constexpr complex& operator=(T re) { + real_ = re; + imag_ = 0; + return *this; + } + + constexpr complex& operator+=(T re) { + real_ += re; + return *this; + } + + constexpr complex& operator-=(T re) { + real_ -= re; + return *this; + } + + constexpr complex& operator*=(T re) { + real_ *= re; + imag_ *= re; + return *this; + } + + constexpr complex& operator/=(T re) { + real_ /= re; + imag_ /= re; + return *this; + } + + template + constexpr complex& operator=(const complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + + template + constexpr complex& operator+=(const complex& rhs) { + real_ += rhs.real(); + imag_ += rhs.imag(); + return *this; + } + + template + constexpr complex& operator-=(const complex& rhs) { + real_ -= rhs.real(); + imag_ -= rhs.imag(); + return *this; + } + + template + constexpr complex& operator*=(const complex& rhs) { + // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } + +#ifdef __APPLE__ +#define FORCE_INLINE_APPLE __attribute__((always_inline)) +#else +#define FORCE_INLINE_APPLE +#endif + template + constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) + __ubsan_ignore_float_divide_by_zero__ { + // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i + // the calculation below follows numpy's complex division + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + +#if defined(__GNUC__) && !defined(__clang__) + // std::abs is already constexpr by gcc + auto abs_c = std::abs(c); + auto abs_d = std::abs(d); +#else + auto abs_c = c < 0 ? -c : c; + auto abs_d = d < 0 ? -d : d; +#endif + + if (abs_c >= abs_d) { + if (abs_c == U(0) && abs_d == U(0)) { + /* divide by zeros should yield a complex inf or nan */ + real_ = a / abs_c; + imag_ = b / abs_d; + } else { + auto rat = d / c; + auto scl = U(1.0) / (c + d * rat); + real_ = (a + b * rat) * scl; + imag_ = (b - a * rat) * scl; + } + } else { + auto rat = c / d; + auto scl = U(1.0) / (d + c * rat); + real_ = (a * rat + b) * scl; + imag_ = (b * rat - a) * scl; + } + return *this; + } +#undef FORCE_INLINE_APPLE + + template + constexpr complex& operator=(const std::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + STANDALONE_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } +#endif + + template + explicit constexpr operator std::complex() const { + return std::complex(std::complex(real(), imag())); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + STANDALONE_HOST_DEVICE explicit operator thrust::complex() const { + return static_cast>(thrust::complex(real(), imag())); + } +#endif + + // consistent with NumPy behavior + explicit constexpr operator bool() const { + return real() || imag(); + } + + STANDALONE_HOST_DEVICE constexpr T real() const { + return real_; + } + constexpr void real(T value) { + real_ = value; + } + STANDALONE_HOST_DEVICE constexpr T imag() const { + return imag_; + } + constexpr void imag(T value) { + imag_ = value; + } +}; + +namespace complex_literals { + +constexpr complex operator""_if(long double imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(long double imag) { + return complex(0.0, static_cast(imag)); +} + +constexpr complex operator""_if(unsigned long long imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(unsigned long long imag) { + return complex(0.0, static_cast(imag)); +} + +} // namespace complex_literals + +template +constexpr complex operator+(const complex& val) { + return val; +} + +template +constexpr complex operator-(const complex& val) { + return complex(-val.real(), -val.imag()); +} + +template +constexpr complex operator+(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const complex& lhs, const T& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const T& lhs, const complex& rhs) { + return complex(lhs + rhs.real(), rhs.imag()); +} + +template +constexpr complex operator-(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const complex& lhs, const T& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const T& lhs, const complex& rhs) { + complex result = -rhs; + return result += lhs; +} + +template +constexpr complex operator*(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const complex& lhs, const T& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const T& lhs, const complex& rhs) { + complex result = rhs; + return result *= lhs; +} + +template +constexpr complex operator/(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const complex& lhs, const T& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const T& lhs, const complex& rhs) { + complex result(lhs, T()); + return result /= rhs; +} + +// Define operators between integral scalars and standalone::c10::complex. +// std::complex does not support this when T is a floating-point number. This is +// useful because it saves a lot of "static_cast" when operate a complex and an +// integer. This makes the code both less verbose and potentially more +// efficient. +#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ + typename std::enable_if_t< \ + std::is_floating_point_v && std::is_integral_v, \ + int> = 0 + +template +constexpr standalone::c10::complex operator+( + const standalone::c10::complex& a, + const iT& b) { + return a + static_cast(b); +} + +template +constexpr standalone::c10::complex operator+( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) + b; +} + +template +constexpr standalone::c10::complex operator-( + const standalone::c10::complex& a, + const iT& b) { + return a - static_cast(b); +} + +template +constexpr standalone::c10::complex operator-( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) - b; +} + +template +constexpr standalone::c10::complex operator*( + const standalone::c10::complex& a, + const iT& b) { + return a * static_cast(b); +} + +template +constexpr standalone::c10::complex operator*( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) * b; +} + +template +constexpr standalone::c10::complex operator/( + const standalone::c10::complex& a, + const iT& b) { + return a / static_cast(b); +} + +template +constexpr standalone::c10::complex operator/( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) / b; +} + +#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION + +template +constexpr bool operator==(const complex& lhs, const complex& rhs) { + return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); +} + +template +constexpr bool operator==(const complex& lhs, const T& rhs) { + return (lhs.real() == rhs) && (lhs.imag() == T()); +} + +template +constexpr bool operator==(const T& lhs, const complex& rhs) { + return (lhs == rhs.real()) && (T() == rhs.imag()); +} + +template +constexpr bool operator!=(const complex& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const complex& lhs, const T& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const T& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const complex& x) { + return (os << static_cast>(x)); +} + +template +std::basic_istream& operator>>( + std::basic_istream& is, + complex& x) { + std::complex tmp; + is >> tmp; + x = tmp; + return is; +} + +} // namespace standalone::c10 + +// std functions +// +// The implementation of these functions also follow the design of C++20 + +namespace std { + +template +constexpr T real(const standalone::c10::complex& z) { + return z.real(); +} + +template +constexpr T imag(const standalone::c10::complex& z) { + return z.imag(); +} + +template +STANDALONE_HOST_DEVICE T abs(const standalone::c10::complex& z) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return thrust::abs(static_cast>(z)); +#else + return std::abs(static_cast>(z)); +#endif +} + +#if defined(USE_ROCM) +#define ROCm_Bug(x) +#else +#define ROCm_Bug(x) x +#endif + +template +STANDALONE_HOST_DEVICE T arg(const standalone::c10::complex& z) { + return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); +} + +#undef ROCm_Bug + +template +constexpr T norm(const standalone::c10::complex& z) { + return z.real() * z.real() + z.imag() * z.imag(); +} + +// For std::conj, there are other versions of it: +// constexpr std::complex conj( float z ); +// template< class DoubleOrInteger > +// constexpr std::complex conj( DoubleOrInteger z ); +// constexpr std::complex conj( long double z ); +// These are not implemented +// TODO(@zasdfgbnm): implement them as standalone::c10::conj +template +constexpr standalone::c10::complex conj( + const standalone::c10::complex& z) { + return standalone::c10::complex(z.real(), -z.imag()); +} + +// Thrust does not have complex --> complex version of thrust::proj, +// so this function is not implemented at standalone right now. +// TODO(@zasdfgbnm): implement it by ourselves + +// There is no standalone version of std::polar, because std::polar always +// returns std::complex. Use standalone::c10::polar instead; + +} // namespace std + +namespace standalone::c10 { + +template +STANDALONE_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::polar(r, theta)); +#else + // std::polar() requires r >= 0, so spell out the explicit implementation to + // avoid a branch. + return complex(r * std::cos(theta), r * std::sin(theta)); +#endif +} + +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + STANDALONE_HOST_DEVICE explicit inline complex( + const Half& real, + const Half& imag) + : real_(real), imag_(imag) {} + STANDALONE_HOST_DEVICE inline complex( + const standalone::c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline STANDALONE_HOST_DEVICE operator standalone::c10::complex() + const { + return {real_, imag_}; + } + + constexpr STANDALONE_HOST_DEVICE Half real() const { + return real_; + } + constexpr STANDALONE_HOST_DEVICE Half imag() const { + return imag_; + } + + STANDALONE_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + STANDALONE_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + STANDALONE_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +} // namespace standalone::c10 + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +#define STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H +// math functions are included in a separate file +#include // IWYU pragma: keep +// utilities for complex types +#include // IWYU pragma: keep +#undef STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H diff --git a/backends/aoti/slim/c10/util/complex_math.h b/backends/aoti/slim/c10/util/complex_math.h new file mode 100644 index 00000000000..56fc84fe90b --- /dev/null +++ b/backends/aoti/slim/c10/util/complex_math.h @@ -0,0 +1,500 @@ +#if !defined(STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "standalone/c10/util/complex_math.h is not meant to be individually included. Include standalone/c10/util/complex.h instead." +#endif + +#include + +namespace standalone::c10::complex_math { + +// Exponential functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex exp( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::exp(static_cast>(x))); +#else + return static_cast>( + std::exp(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log(static_cast>(x))); +#else + return static_cast>( + std::log(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log10( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log10(static_cast>(x))); +#else + return static_cast>( + std::log10(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log2( + const standalone::c10::complex& x) { + const standalone::c10::complex log2 = + standalone::c10::complex(::log(2.0), 0.0); + return standalone::c10::complex_math::log(x) / log2; +} + +// Power functions +// +#if defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) +namespace _detail { +template +standalone::c10::complex compute_csqrt( + const standalone::c10::complex& z) { + constexpr auto half = T(.5); + + // Trust standard library to correctly handle infs and NaNs + if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || + std::isnan(z.imag())) { + return static_cast>( + std::sqrt(static_cast>(z))); + } + + // Special case for square root of pure imaginary values + if (z.real() == T(0)) { + if (z.imag() == T(0)) { + return standalone::c10::complex(T(0), z.imag()); + } + auto v = std::sqrt(half * std::abs(z.imag())); + return standalone::c10::complex(v, std::copysign(v, z.imag())); + } + + // At this point, z is non-zero and finite + if (z.real() >= 0.0) { + auto t = std::sqrt((z.real() + std::abs(z)) * half); + return standalone::c10::complex(t, half * (z.imag() / t)); + } + + auto t = std::sqrt((-z.real() + std::abs(z)) * half); + return standalone::c10::complex( + half * std::abs(z.imag() / t), std::copysign(t, z.imag())); +} + +// Compute complex arccosine using formula from W. Kahan +// "Branch Cuts for Complex Elementary Functions" 1986 paper: +// cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re()) +// cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im()) +template +standalone::c10::complex compute_cacos( + const standalone::c10::complex& z) { + auto constexpr one = T(1); + // Trust standard library to correctly handle infs and NaNs + if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || + std::isnan(z.imag())) { + return static_cast>( + std::acos(static_cast>(z))); + } + auto a = + compute_csqrt(standalone::c10::complex(one - z.real(), -z.imag())); + auto b = compute_csqrt(standalone::c10::complex(one + z.real(), z.imag())); + auto c = + compute_csqrt(standalone::c10::complex(one + z.real(), -z.imag())); + auto r = T(2) * std::atan2(a.real(), b.real()); + // Explicitly unroll (a*c).imag() + auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real()); + return standalone::c10::complex(r, i); +} + +inline standalone::c10::complex sqrt( + const standalone::c10::complex& in) { + return compute_csqrt(in); +} + +inline standalone::c10::complex sqrt( + const standalone::c10::complex& in) { + return compute_csqrt(in); +} + +inline standalone::c10::complex acos( + const standalone::c10::complex& in) { + return compute_cacos(in); +} + +inline standalone::c10::complex acos( + const standalone::c10::complex& in) { + return compute_cacos(in); +} +} // namespace _detail +#endif + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sqrt( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sqrt(static_cast>(x))); +#elif !( \ + defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) + return static_cast>( + std::sqrt(static_cast>(x))); +#else + return _detail::sqrt(x); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const T& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const T& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const U& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const T& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +// Trigonometric functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sin( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sin(static_cast>(x))); +#else + return static_cast>( + std::sin(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex cos( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cos(static_cast>(x))); +#else + return static_cast>( + std::cos(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex tan( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tan(static_cast>(x))); +#else + return static_cast>( + std::tan(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex asin( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asin(static_cast>(x))); +#else + return static_cast>( + std::asin(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex acos( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acos(static_cast>(x))); +#elif !defined(_LIBCPP_VERSION) + return static_cast>( + std::acos(static_cast>(x))); +#else + return _detail::acos(x); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex atan( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atan(static_cast>(x))); +#else + return static_cast>( + std::atan(static_cast>(x))); +#endif +} + +// Hyperbolic functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sinh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sinh(static_cast>(x))); +#else + return static_cast>( + std::sinh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex cosh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cosh(static_cast>(x))); +#else + return static_cast>( + std::cosh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex tanh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tanh(static_cast>(x))); +#else + return static_cast>( + std::tanh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex asinh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asinh(static_cast>(x))); +#else + return static_cast>( + std::asinh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex acosh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acosh(static_cast>(x))); +#else + return static_cast>( + std::acosh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex atanh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atanh(static_cast>(x))); +#else + return static_cast>( + std::atanh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log1p( + const standalone::c10::complex& z) { +#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \ + defined(__HIPCC__) + // For Mac, the new implementation yielded a high relative error. Falling back + // to the old version for now. + // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + // For CUDA we also use this one, as thrust::log(thrust::complex) takes + // *forever* to compile + + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +#else + // CPU path + // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + standalone::c10::complex u = z + T(1); + if (u == T(1)) { + return z; + } else { + auto log_u = log(u); + if (u - T(1) == z) { + return log_u; + } + return log_u * (z / (u - T(1))); + } +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex expm1( + const standalone::c10::complex& z) { + // expm1(z) = exp(z) - 1 + // Define z = x + i * y + // f = e ^ (x + i * y) - 1 + // = e ^ x * e ^ (i * y) - 1 + // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y)) + // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y) + // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y) + T x = z.real(); + T y = z.imag(); + T a = std::sin(y / 2); + T er = std::expm1(x) * std::cos(y) - T(2) * a * a; + T ei = std::exp(x) * std::sin(y); + return {er, ei}; +} + +} // namespace standalone::c10::complex_math + +using standalone::c10::complex_math::acos; +using standalone::c10::complex_math::acosh; +using standalone::c10::complex_math::asin; +using standalone::c10::complex_math::asinh; +using standalone::c10::complex_math::atan; +using standalone::c10::complex_math::atanh; +using standalone::c10::complex_math::cos; +using standalone::c10::complex_math::cosh; +using standalone::c10::complex_math::exp; +using standalone::c10::complex_math::expm1; +using standalone::c10::complex_math::log; +using standalone::c10::complex_math::log10; +using standalone::c10::complex_math::log1p; +using standalone::c10::complex_math::log2; +using standalone::c10::complex_math::pow; +using standalone::c10::complex_math::sin; +using standalone::c10::complex_math::sinh; +using standalone::c10::complex_math::sqrt; +using standalone::c10::complex_math::tan; +using standalone::c10::complex_math::tanh; + +namespace std { + +using standalone::c10::complex_math::acos; +using standalone::c10::complex_math::acosh; +using standalone::c10::complex_math::asin; +using standalone::c10::complex_math::asinh; +using standalone::c10::complex_math::atan; +using standalone::c10::complex_math::atanh; +using standalone::c10::complex_math::cos; +using standalone::c10::complex_math::cosh; +using standalone::c10::complex_math::exp; +using standalone::c10::complex_math::expm1; +using standalone::c10::complex_math::log; +using standalone::c10::complex_math::log10; +using standalone::c10::complex_math::log1p; +using standalone::c10::complex_math::log2; +using standalone::c10::complex_math::pow; +using standalone::c10::complex_math::sin; +using standalone::c10::complex_math::sinh; +using standalone::c10::complex_math::sqrt; +using standalone::c10::complex_math::tan; +using standalone::c10::complex_math::tanh; + +} // namespace std diff --git a/backends/aoti/slim/c10/util/complex_utils.h b/backends/aoti/slim/c10/util/complex_utils.h new file mode 100644 index 00000000000..5b29406a186 --- /dev/null +++ b/backends/aoti/slim/c10/util/complex_utils.h @@ -0,0 +1,46 @@ +#if !defined(STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "standalone/c10/util/complex_utils.h is not meant to be individually included. Include standalone/c10/util/complex.h instead." +#endif + +#include + +namespace standalone::c10 { + +template +struct is_complex : public std::false_type {}; + +template +struct is_complex> : public std::true_type {}; + +template +struct is_complex> : public std::true_type {}; + +// Extract double from std::complex; is identity otherwise +// TODO: Write in more idiomatic C++17 +template +struct scalar_value_type { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; + +} // namespace standalone::c10 + +namespace std { + +template +class numeric_limits> : public numeric_limits {}; + +template +bool isnan(const standalone::c10::complex& v) { + return std::isnan(v.real()) || std::isnan(v.imag()); +} + +} // namespace std diff --git a/backends/aoti/slim/c10/util/copysign.h b/backends/aoti/slim/c10/util/copysign.h new file mode 100644 index 00000000000..1012934049c --- /dev/null +++ b/backends/aoti/slim/c10/util/copysign.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { + +// Note: Explicit implementation of copysign for Half and BFloat16 +// is needed to workaround g++-7/8 crash on aarch64, but also makes +// copysign faster for the half-precision types +template +inline auto copysign(const T& a, const U& b) { + return std::copysign(a, b); +} + +// Implement copysign for half precision floats using bit ops +// Sign is the most significant bit for both half and bfloat16 types +inline Half copysign(Half a, Half b) { + return Half((a.x & 0x7fff) | (b.x & 0x8000), Half::from_bits()); +} + +inline BFloat16 copysign(BFloat16 a, BFloat16 b) { + return BFloat16((a.x & 0x7fff) | (b.x & 0x8000), BFloat16::from_bits()); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/floating_point_utils.h b/backends/aoti/slim/c10/util/floating_point_utils.h new file mode 100644 index 00000000000..259cb93b0a5 --- /dev/null +++ b/backends/aoti/slim/c10/util/floating_point_utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +namespace standalone::c10::detail { + +STANDALONE_HOST_DEVICE inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + return standalone::c10::bit_cast(w); +#endif +} + +STANDALONE_HOST_DEVICE inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + return standalone::c10::bit_cast(f); +#endif +} + +} // namespace standalone::c10::detail diff --git a/backends/aoti/slim/c10/util/generic_math.h b/backends/aoti/slim/c10/util/generic_math.h new file mode 100644 index 00000000000..00bb4265d9d --- /dev/null +++ b/backends/aoti/slim/c10/util/generic_math.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include + +#if defined(__CUDA_ARCH__) +#include +#define STANDALONE_COMPAT_COPYSIGN standalone::c10::cuda::compat::copysign +// TODO: rocm is not supported yet +// #elif defined(__HIPCC__) +// #include +// #define STANDALONE_COMPAT_COPYSIGN standalone::c10::hip::compat::copysign +#else +#include +#define STANDALONE_COMPAT_COPYSIGN standalone::c10::copysign +#endif + +// The functions in this file should be header-only as it is used under +// ABI-compatibility mode. + +namespace standalone::c10 { + +// NOTE: [Floor Division in Python] +// Python's __floordiv__ operator is more complicated than just floor(a / b). +// It aims to maintain the property: a == (a // b) * b + remainder(a, b) +// which can otherwise fail due to rounding errors in the remainder. +// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b +// With some additional fix-ups added to the result. +// +// For reference, see CPython's implementation: +// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + +template +inline STANDALONE_HOST_DEVICE scalar_t div_floor_floating( + scalar_t a, + scalar_t b) __ubsan_ignore_float_divide_by_zero__ { + if (STANDALONE_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = STANDALONE_COMPAT_COPYSIGN(scalar_t(0), a / b); + } + return floordiv; +} + +template +inline STANDALONE_HOST_DEVICE scalar_t +div_floor_integer(scalar_t a, scalar_t b) { + if (standalone::c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the remainder of + // the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + return a / b; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline STANDALONE_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (STANDALONE_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return std::fmod(a, b); + } + + auto mod = std::fmod(a, b); + if (mod == 0) { + mod = STANDALONE_COMPAT_COPYSIGN(scalar_t(0), b); + } else if ((b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline STANDALONE_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { + auto mod = a % b; + if (mod != 0 && (b > 0) != (mod > 0)) { + mod += b; + } + return mod; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/irange.h b/backends/aoti/slim/c10/util/irange.h new file mode 100644 index 00000000000..0d10f373a04 --- /dev/null +++ b/backends/aoti/slim/c10/util/irange.h @@ -0,0 +1,123 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include + +namespace standalone::c10 { + +namespace detail { + +template < + typename I, + bool one_sided = false, + std::enable_if_t, int> = 0> +struct integer_iterator { + using iterator_category = std::input_iterator_tag; + using value_type = I; + using difference_type = std::ptrdiff_t; + using pointer = I*; + using reference = I&; + + explicit constexpr integer_iterator(I value) : value(value) {} + + constexpr I operator*() const { + return value; + } + + constexpr I const* operator->() const { + return &value; + } + + constexpr integer_iterator& operator++() { + ++value; + return *this; + } + + constexpr integer_iterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + constexpr bool operator==(const integer_iterator& other) const { + if constexpr (one_sided) { + // Range-for loops' end test is `begin != end`, not `begin < + // end`. To handle `standalone::c10::irange(n)` where n < 0 (which + // should be empty), we just make `begin != end` fail whenever `end` is + // negative. + return is_negative(other.value) || value == other.value; + } else { + return value == other.value; + } + // Suppress "warning: missing return statement at end of non-void function" + // which Nvidia's Robert Crovella confirms is an NVCC compiler error + // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 + // `__builtin_unreachable();` would be best here, but it's not + // available with all compilers. So we instead return an arbitrary + // value trusting that this line will, in fact, never be reached. + return false; // Horrible hack + } + + constexpr bool operator!=(const integer_iterator& other) const { + return !(*this == other); + } + + protected: + I value; +}; + +} // namespace detail + +template < + typename I, + bool one_sided = false, + std::enable_if_t, bool> = true> +struct integer_range { + public: + constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} + using iterator = detail::integer_iterator; + constexpr iterator begin() const { + return begin_; + } + constexpr iterator end() const { + return end_; + } + + private: + iterator begin_; + iterator end_; +}; + +/// Creates an integer range for the half-open interval [begin, end) +/// If end<=begin, then the range is empty. +/// The range has the type of the `end` integer; `begin` integer is +/// cast to this type. +template < + typename Integer1, + typename Integer2, + std::enable_if_t, bool> = true, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer1 begin, Integer2 end) { + // If end<=begin then the range is empty; we can achieve this effect by + // choosing the larger of {begin, end} as the loop terminator + return { + static_cast(begin), + std::max(static_cast(begin), end)}; +} + +/// Creates an integer range for the half-open interval [0, end) +/// If end<=begin, then the range is empty +template < + typename Integer, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer end) { + return {Integer(), end}; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/llvmMathExtras.h b/backends/aoti/slim/c10/util/llvmMathExtras.h new file mode 100644 index 00000000000..0b4f92c44c6 --- /dev/null +++ b/backends/aoti/slim/c10/util/llvmMathExtras.h @@ -0,0 +1,899 @@ +//===-- llvm/Support/MathExtras.h - Useful math functions -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some functions that are useful for math stuff. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __ANDROID_NDK__ +#include +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef LLVM_GNUC_PREREQ +#if defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) + __GNUC_PATCHLEVEL__ >= \ + ((maj) << 20) + ((min) << 10) + (patch)) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) >= ((maj) << 20) + ((min) << 10)) +#else +#define LLVM_GNUC_PREREQ(maj, min, patch) 0 +#endif +#endif + +#ifdef _MSC_VER +// Declare these intrinsics manually rather including intrin.h. It's very +// expensive, and MathExtras.h is popular. +// #include +extern "C" { +unsigned char _BitScanForward(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanForward64(unsigned long* _Index, unsigned __int64 _Mask); +unsigned char _BitScanReverse(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanReverse64(unsigned long* _Index, unsigned __int64 _Mask); +} +#endif + +namespace standalone::c10::llvm { +/// The behavior an operation has on an input of 0. +enum ZeroBehavior { + /// The returned value is undefined. + ZB_Undefined, + /// The returned value is numeric_limits::max() + ZB_Max, + /// The returned value is numeric_limits::digits + ZB_Width +}; + +namespace detail { +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior) { + if (!Val) + return std::numeric_limits::digits; + if (Val & 0x1) + return 0; + + // Bisection method. + std::size_t ZeroBits = 0; + T Shift = std::numeric_limits::digits >> 1; + T Mask = std::numeric_limits::max() >> Shift; + while (Shift) { + if ((Val & Mask) == 0) { + Val >>= Shift; + ZeroBits |= Shift; + } + Shift >>= 1; + Mask >>= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_ctz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward(&Index, Val); + return Index; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_ctzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward64(&Index, Val); + return Index; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the least significant bit to the most +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countTrailingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::TrailingZerosCounter::count(Val, ZB); +} + +namespace detail { +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior) { + if (!Val) + return std::numeric_limits::digits; + + // Bisection method. + std::size_t ZeroBits = 0; + for (T Shift = std::numeric_limits::digits >> 1; Shift; Shift >>= 1) { + T Tmp = Val >> Shift; + if (Tmp) + Val = Tmp; + else + ZeroBits |= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_clz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse(&Index, Val); + return Index ^ 31; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_clzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse64(&Index, Val); + return Index ^ 63; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the most significant bit to the least +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countLeadingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::LeadingZerosCounter::count(Val, ZB); +} + +/// Get the index of the first set bit starting from the least +/// significant bit. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findFirstSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + return countTrailingZeros(Val, ZB_Undefined); +} + +/// Create a bitmask with the N right-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskTrailingOnes(unsigned N) { + static_assert(std::is_unsigned_v, "Invalid type!"); + const unsigned Bits = CHAR_BIT * sizeof(T); + assert(N <= Bits && "Invalid bit index"); + return N == 0 ? 0 : (T(-1) >> (Bits - N)); +} + +/// Create a bitmask with the N left-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskLeadingOnes(unsigned N) { + return ~maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N right-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskTrailingZeros(unsigned N) { + return maskLeadingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N left-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskLeadingZeros(unsigned N) { + return maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Get the index of the last set bit starting from the least +/// significant bit. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findLastSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + // Use ^ instead of - because both gcc and llvm can remove the associated ^ + // in the __builtin_clz intrinsic on x86. + return countLeadingZeros(Val, ZB_Undefined) ^ + (std::numeric_limits::digits - 1); +} + +/// Macro compressed bit reversal table for 256 bits. +/// +/// http://graphics.stanford.edu/~seander/bithacks.html#BitReverseTable +/// NOLINTNEXTLINE(*c-arrays*) +static constexpr unsigned char BitReverseTable256[256] = { +#define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64 +#define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16) +#define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4) + R6(0), + R6(2), + R6(1), + R6(3) +#undef R2 +#undef R4 +#undef R6 +}; + +/// Reverse the bits in \p Val. +template +T reverseBits(T Val) { + // NOLINTNEXTLINE(*c-arrays*) + unsigned char in[sizeof(Val)]; + // NOLINTNEXTLINE(*c-arrays*) + unsigned char out[sizeof(Val)]; + std::memcpy(in, &Val, sizeof(Val)); + for (unsigned i = 0; i < sizeof(Val); ++i) + out[(sizeof(Val) - i) - 1] = BitReverseTable256[in[i]]; + std::memcpy(&Val, out, sizeof(Val)); + return Val; +} + +// NOTE: The following support functions use the _32/_64 extensions instead of +// type overloading so that signed and unsigned integers can be used without +// ambiguity. + +/// Return the high 32 bits of a 64 bit value. +constexpr inline uint32_t Hi_32(uint64_t Value) { + return static_cast(Value >> 32); +} + +/// Return the low 32 bits of a 64 bit value. +constexpr inline uint32_t Lo_32(uint64_t Value) { + return static_cast(Value); +} + +/// Make a 64-bit integer from a high / low pair of 32-bit integers. +constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) { + return ((uint64_t)High << 32) | (uint64_t)Low; +} + +/// Checks if an integer fits into the given bit width. +template +constexpr inline bool isInt(int64_t x) { + return N >= 64 || + (-(INT64_C(1) << (N - 1)) <= x && x < (INT64_C(1) << (N - 1))); +} +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isInt<8>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<16>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<32>(int64_t x) { + return static_cast(x) == x; +} + +/// Checks if a signed integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedInt(int64_t x) { + static_assert( + N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number."); + static_assert(N + S <= 64, "isShiftedInt with N + S > 64 is too wide."); + return isInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Checks if an unsigned integer fits into the given bit width. +/// +/// This is written as two functions rather than as simply +/// +/// return N >= 64 || X < (UINT64_C(1) << N); +/// +/// to keep MSVC from (incorrectly) warning on isUInt<64> that we're shifting +/// left too many places. +template +constexpr inline std::enable_if_t<(N < 64), bool> isUInt(uint64_t X) { + static_assert(N > 0, "isUInt<0> doesn't make sense"); + return X < (UINT64_C(1) << (N)); +} +template +constexpr inline std::enable_if_t= 64, bool> isUInt(uint64_t /*X*/) { + return true; +} + +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isUInt<8>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<16>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<32>(uint64_t x) { + return static_cast(x) == x; +} + +/// Checks if a unsigned integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedUInt(uint64_t x) { + static_assert( + N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)"); + static_assert( + N + S <= 64, "isShiftedUInt with N + S > 64 is too wide."); + // Per the two static_asserts above, S must be strictly less than 64. So + // 1 << S is not undefined behavior. + return isUInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Gets the maximum value for a N-bit unsigned integer. +inline uint64_t maxUIntN(uint64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // uint64_t(1) << 64 is undefined behavior, so we can't do + // (uint64_t(1) << N) - 1 + // without checking first that N != 64. But this works and doesn't have a + // branch. + return UINT64_MAX >> (64 - N); +} + +// Ignore the false warning "Arithmetic overflow" for MSVC +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#endif + +/// Gets the minimum value for a N-bit signed integer. +inline int64_t minIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + // NOLINTNEXTLINE(*-narrowing-conversions) + return -(UINT64_C(1) << (N - 1)); +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/// Gets the maximum value for a N-bit signed integer. +inline int64_t maxIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // This relies on two's complement wraparound when N == 64, so we convert to + // int64_t only at the very end to avoid UB. + // NOLINTNEXTLINE(*-narrowing-conversions) + return (UINT64_C(1) << (N - 1)) - 1; +} + +/// Checks if an unsigned integer fits into the given (dynamic) bit width. +inline bool isUIntN(unsigned N, uint64_t x) { + return N >= 64 || x <= maxUIntN(N); +} + +/// Checks if an signed integer fits into the given (dynamic) bit width. +inline bool isIntN(unsigned N, int64_t x) { + return N >= 64 || (minIntN(N) <= x && x <= maxIntN(N)); +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (32 bit version). +/// Ex. isMask_32(0x0000FFFFU) == true. +constexpr inline bool isMask_32(uint32_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (64 bit version). +constexpr inline bool isMask_64(uint64_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true. +constexpr inline bool isShiftedMask_32(uint32_t Value) { + return Value && isMask_32((Value - 1) | Value); +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (64 bit version.) +constexpr inline bool isShiftedMask_64(uint64_t Value) { + return Value && isMask_64((Value - 1) | Value); +} + +/// Return true if the argument is a power of two > 0. +/// Ex. isPowerOf2_32(0x00100000U) == true (32 bit edition.) +constexpr inline bool isPowerOf2_32(uint32_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Return true if the argument is a power of two > 0 (64 bit edition.) +constexpr inline bool isPowerOf2_64(uint64_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Count the number of ones from the most significant bit to the first +/// zero bit. +/// +/// Ex. countLeadingOnes(0xFF0FFF00) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countLeadingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countLeadingZeros(~Value, ZB); +} + +/// Count the number of ones from the least significant bit to the first +/// zero bit. +/// +/// Ex. countTrailingOnes(0x00FF00FF) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countTrailingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countTrailingZeros(~Value, ZB); +} + +namespace detail { +template +struct PopulationCounter { + static unsigned count(T Value) { + // Generic version, forward to 32 bits. + static_assert(SizeOfT <= 4, "Not implemented!"); +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcount(Value); +#else + uint32_t v = Value; + v = v - ((v >> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >> 2) & 0x33333333); + return ((v + (v >> 4) & 0xF0F0F0F) * 0x1010101) >> 24; +#endif + } +}; + +template +struct PopulationCounter { + static unsigned count(T Value) { +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcountll(Value); +#else + uint64_t v = Value; + v = v - ((v >> 1) & 0x5555555555555555ULL); + v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL); + v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL; + return unsigned((uint64_t)(v * 0x0101010101010101ULL) >> 56); +#endif + } +}; +} // namespace detail + +/// Count the number of set bits in a value. +/// Ex. countPopulation(0xF000F000) = 8 +/// Returns 0 if the word is zero. +template +inline unsigned countPopulation(T Value) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return detail::PopulationCounter::count(Value); +} + +/// Return the log base 2 of the specified value. +inline double Log2(double Value) { +#if defined(__ANDROID_API__) && __ANDROID_API__ < 18 + return __builtin_log(Value) / __builtin_log(2.0); +#else + return log2(Value); +#endif +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (32 bit edition.) +/// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 +inline unsigned Log2_32(uint32_t Value) { + return static_cast(31 - countLeadingZeros(Value)); +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64(uint64_t Value) { + return static_cast(63 - countLeadingZeros(Value)); +} + +/// Return the ceil log base 2 of the specified value, 32 if the value is zero. +/// (32 bit edition). +/// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 +inline unsigned Log2_32_Ceil(uint32_t Value) { + return static_cast(32 - countLeadingZeros(Value - 1)); +} + +/// Return the ceil log base 2 of the specified value, 64 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64_Ceil(uint64_t Value) { + return static_cast(64 - countLeadingZeros(Value - 1)); +} + +/// Return the greatest common divisor of the values using Euclid's algorithm. +inline uint64_t GreatestCommonDivisor64(uint64_t A, uint64_t B) { + while (B) { + uint64_t T = B; + B = A % B; + A = T; + } + return A; +} + +/// This function takes a 64-bit integer and returns the bit equivalent double. +inline double BitsToDouble(uint64_t Bits) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double D; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&D, &Bits, sizeof(Bits)); + return D; +} + +/// This function takes a 32-bit integer and returns the bit equivalent float. +inline float BitsToFloat(uint32_t Bits) { + // TODO: Use std::bit_cast once C++20 becomes available. + return standalone::c10::bit_cast(Bits); +} + +/// This function takes a double and returns the bit equivalent 64-bit integer. +/// Note that copying doubles around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. +inline uint64_t DoubleToBits(double Double) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint64_t Bits; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&Bits, &Double, sizeof(Double)); + return Bits; +} + +/// This function takes a float and returns the bit equivalent 32-bit integer. +/// Note that copying floats around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. +inline uint32_t FloatToBits(float Float) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t Bits; + static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); + memcpy(&Bits, &Float, sizeof(Float)); + return Bits; +} + +/// A and B are either alignments or offsets. Return the minimum alignment that +/// may be assumed after adding the two together. +constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) { + // The largest power of 2 that divides both A and B. + // + // Replace "-Value" by "1+~Value" in the following commented code to avoid + // MSVC warning C4146 + // return (A | B) & -(A | B); + return (A | B) & (1 + ~(A | B)); +} + +/// Aligns \c Addr to \c Alignment bytes, rounding up. +/// +/// Alignment should be a power of two. This method rounds up, so +/// alignAddr(7, 4) == 8 and alignAddr(8, 4) == 8. +inline uintptr_t alignAddr(const void* Addr, size_t Alignment) { + assert( + Alignment && isPowerOf2_64((uint64_t)Alignment) && + "Alignment is not a power of two!"); + + assert((uintptr_t)Addr + Alignment - 1 >= (uintptr_t)Addr); + + return (((uintptr_t)Addr + Alignment - 1) & ~(uintptr_t)(Alignment - 1)); +} + +/// Returns the necessary adjustment for aligning \c Ptr to \c Alignment +/// bytes, rounding up. +inline size_t alignmentAdjustment(const void* Ptr, size_t Alignment) { + return alignAddr(Ptr, Alignment) - (uintptr_t)Ptr; +} + +/// Returns the next power of two (in 64-bits) that is strictly greater than A. +/// Returns zero on overflow. +inline uint64_t NextPowerOf2(uint64_t A) { + A |= (A >> 1); + A |= (A >> 2); + A |= (A >> 4); + A |= (A >> 8); + A |= (A >> 16); + A |= (A >> 32); + return A + 1; +} + +/// Returns the power of two which is less than or equal to the given value. +/// Essentially, it is a floor operation across the domain of powers of two. +inline uint64_t PowerOf2Floor(uint64_t A) { + if (!A) + return 0; + return 1ull << (63 - countLeadingZeros(A, ZB_Undefined)); +} + +/// Returns the power of two which is greater than or equal to the given value. +/// Essentially, it is a ceil operation across the domain of powers of two. +inline uint64_t PowerOf2Ceil(uint64_t A) { + if (!A) + return 0; + return NextPowerOf2(A - 1); +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \p Align. \p Align must be non-zero. +/// +/// If non-zero \p Skew is specified, the return value will be a minimal +/// integer that is greater than or equal to \p Value and equal to +/// \p Align * N + \p Skew for some integer N. If \p Skew is larger than +/// \p Align, its value is adjusted to '\p Skew mod \p Align'. +/// +/// Examples: +/// \code +/// alignTo(5, 8) = 8 +/// alignTo(17, 8) = 24 +/// alignTo(~0LL, 8) = 0 +/// alignTo(321, 255) = 510 +/// +/// alignTo(5, 8, 7) = 7 +/// alignTo(17, 8, 1) = 17 +/// alignTo(~0LL, 8, 3) = 3 +/// alignTo(321, 255, 42) = 552 +/// \endcode +inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value + Align - 1 - Skew) / Align * Align + Skew; +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \c Align. \c Align must be non-zero. +template +constexpr inline uint64_t alignTo(uint64_t Value) { + static_assert(Align != 0u, "Align must be non-zero"); + return (Value + Align - 1) / Align * Align; +} + +/// Returns the integer ceil(Numerator / Denominator). +inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) { + return alignTo(Numerator, Denominator) / Denominator; +} + +/// \c alignTo for contexts where a constant expression is required. +/// \sa alignTo +/// +/// \todo FIXME: remove when \c constexpr becomes really \c constexpr +template +struct AlignTo { + static_assert(Align != 0u, "Align must be non-zero"); + template + struct from_value { + static const uint64_t value = (Value + Align - 1) / Align * Align; + }; +}; + +/// Returns the largest uint64_t less than or equal to \p Value and is +/// \p Skew mod \p Align. \p Align must be non-zero +inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value - Skew) / Align * Align + Skew; +} + +/// Returns the offset to the next integer (mod 2**64) that is greater than +/// or equal to \p Value and is a multiple of \p Align. \p Align must be +/// non-zero. +inline uint64_t OffsetToAlignment(uint64_t Value, uint64_t Align) { + return alignTo(Value, Align) - Value; +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B <= 32. +template +constexpr inline int32_t SignExtend32(uint32_t X) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 32, "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B < 32. +inline int32_t SignExtend32(uint32_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 32 && "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +template +constexpr inline int64_t SignExtend64(uint64_t x) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 64, "Bit width out of range."); + return int64_t(x << (64 - B)) >> (64 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +inline int64_t SignExtend64(uint64_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 64 && "Bit width out of range."); + return int64_t(X << (64 - B)) >> (64 - B); +} + +/// Subtract two unsigned integers, X and Y, of type T and return the absolute +/// value of the result. +template +std::enable_if_t, T> AbsoluteDifference(T X, T Y) { + return std::max(X, Y) - std::min(X, Y); +} + +/// Add two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingAdd(T X, T Y, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + // Hacker's Delight, p. 29 + T Z = X + Y; + Overflowed = (Z < X || Z < Y); + if (Overflowed) + return std::numeric_limits::max(); + else + return Z; +} + +/// Multiply two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingMultiply(T X, T Y, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + // Hacker's Delight, p. 30 has a different algorithm, but we don't use that + // because it fails for uint16_t (where multiplication can have undefined + // behavior due to promotion to int), and requires a division in addition + // to the multiplication. + + Overflowed = false; + + // Log2(Z) would be either Log2Z or Log2Z + 1. + // Special case: if X or Y is 0, Log2_64 gives -1, and Log2Z + // will necessarily be less than Log2Max as desired. + int Log2Z = Log2_64(X) + Log2_64(Y); + const T Max = std::numeric_limits::max(); + int Log2Max = Log2_64(Max); + if (Log2Z < Log2Max) { + return X * Y; + } + if (Log2Z > Log2Max) { + Overflowed = true; + return Max; + } + + // We're going to use the top bit, and maybe overflow one + // bit past it. Multiply all but the bottom bit then add + // that on at the end. + T Z = (X >> 1) * Y; + if (Z & ~(Max >> 1)) { + Overflowed = true; + return Max; + } + Z <<= 1; + if (X & 1) + return SaturatingAdd(Z, Y, ResultOverflowed); + + return Z; +} + +/// Multiply two unsigned integers, X and Y, and add the unsigned integer, A to +/// the product. Clamp the result to the maximum representable value of T on +/// overflow. ResultOverflowed indicates if the result is larger than the +/// maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingMultiplyAdd(T X, T Y, T A, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + T Product = SaturatingMultiply(X, Y, &Overflowed); + if (Overflowed) + return Product; + + return SaturatingAdd(A, Product, &Overflowed); +} + +/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC. +extern const float huge_valf; +} // namespace standalone::c10::llvm diff --git a/backends/aoti/slim/c10/util/overflows.h b/backends/aoti/slim/c10/util/overflows.h new file mode 100644 index 00000000000..5f636cd1a75 --- /dev/null +++ b/backends/aoti/slim/c10/util/overflows.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace standalone::c10 { +// In some versions of MSVC, there will be a compiler error when building. +// C4146: unary minus operator applied to unsigned type, result still unsigned +// C4804: unsafe use of type 'bool' in operation +// It can be addressed by disabling the following warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#pragma warning(disable : 4804) +#pragma warning(disable : 4018) +#endif + +// The overflow checks may involve float to int conversion which may +// trigger precision loss warning. Re-enable the warning once the code +// is fixed. See T58053069. +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +// bool can be converted to any type. +// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: +// `error: comparison of constant '255' with boolean expression is always false` +// for `f > limit::max()` below +template +std::enable_if_t, bool> overflows( + From /*f*/, + bool strict_unsigned [[maybe_unused]] = false) { + return false; +} + +// skip isnan and isinf check for integral types +template +std::enable_if_t && !std::is_same_v, bool> +overflows(From f, bool strict_unsigned = false) { + using limit = std::numeric_limits::type>; + if constexpr (!limit::is_signed && std::numeric_limits::is_signed) { + // allow for negative numbers to wrap using two's complement arithmetic. + // For example, with uint8, this allows for `a - b` to be treated as + // `a + 255 * b`. + if (!strict_unsigned) { + return greater_than_max(f) || + (standalone::c10::is_negative(f) && + -static_cast(f) > static_cast(limit::max())); + } + } + return standalone::c10::less_than_lowest(f) || greater_than_max(f); +} + +template +std::enable_if_t, bool> overflows( + From f, + bool strict_unsigned [[maybe_unused]] = false) { + using limit = std::numeric_limits::type>; + if (limit::has_infinity && std::isinf(static_cast(f))) { + return false; + } + if (!limit::has_quiet_NaN && (f != f)) { + return true; + } + return f < limit::lowest() || f > limit::max(); +} + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template +std::enable_if_t::value, bool> overflows( + From f, + bool strict_unsigned = false) { + // casts from complex to real are considered to overflow if the + // imaginary component is non-zero + if (!is_complex::value && f.imag() != 0) { + return true; + } + // Check for overflow componentwise + // (Technically, the imag overflow check is guaranteed to be false + // when !is_complex, but any optimizer worth its salt will be + // able to figure it out.) + return overflows< + typename scalar_value_type::type, + typename From::value_type>(f.real(), strict_unsigned) || + overflows< + typename scalar_value_type::type, + typename From::value_type>(f.imag(), strict_unsigned); +} +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/qint32.h b/backends/aoti/slim/c10/util/qint32.h new file mode 100644 index 00000000000..7951bfd240a --- /dev/null +++ b/backends/aoti/slim/c10/util/qint32.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * qint32 is for signed 32 bit quantized Tensors + */ +struct alignas(4) qint32 { + using underlying = int32_t; + int32_t val_; + qint32() = default; + STANDALONE_HOST_DEVICE explicit qint32(int32_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/qint8.h b/backends/aoti/slim/c10/util/qint8.h new file mode 100644 index 00000000000..53c1fdf465a --- /dev/null +++ b/backends/aoti/slim/c10/util/qint8.h @@ -0,0 +1,20 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * This is the data type for quantized Tensors. Right now we only have + * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors, + * we might have 4 bit, 2 bit or 1 bit data types in the future. + */ +struct alignas(1) qint8 { + using underlying = int8_t; + int8_t val_; + qint8() = default; + STANDALONE_HOST_DEVICE explicit qint8(int8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint2x4.h b/backends/aoti/slim/c10/util/quint2x4.h new file mode 100644 index 00000000000..009802be7f2 --- /dev/null +++ b/backends/aoti/slim/c10/util/quint2x4.h @@ -0,0 +1,19 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint2x4 is for un-signed 2 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint2x4 { + using underlying = uint8_t; + uint8_t val_; + quint2x4() = default; + STANDALONE_HOST_DEVICE explicit quint2x4(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint4x2.h b/backends/aoti/slim/c10/util/quint4x2.h new file mode 100644 index 00000000000..b6812ab8fde --- /dev/null +++ b/backends/aoti/slim/c10/util/quint4x2.h @@ -0,0 +1,19 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + STANDALONE_HOST_DEVICE explicit quint4x2(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint8.h b/backends/aoti/slim/c10/util/quint8.h new file mode 100644 index 00000000000..4019765ca4a --- /dev/null +++ b/backends/aoti/slim/c10/util/quint8.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint8 is for unsigned 8 bit quantized Tensors + */ +struct alignas(1) quint8 { + using underlying = uint8_t; + uint8_t val_; + quint8() = default; + STANDALONE_HOST_DEVICE explicit quint8(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/safe_numerics.h b/backends/aoti/slim/c10/util/safe_numerics.h new file mode 100644 index 00000000000..26a05c636aa --- /dev/null +++ b/backends/aoti/slim/c10/util/safe_numerics.h @@ -0,0 +1,94 @@ +#pragma once +#include + +#include + +// GCC has __builtin_mul_overflow from before it supported __has_builtin +#ifdef _MSC_VER +#define STANDALONE_HAS_BUILTIN_OVERFLOW() (0) +#include +#include +#else +#define STANDALONE_HAS_BUILTIN_OVERFLOW() (1) +#endif + +namespace standalone::c10 { + +STANDALONE_ALWAYS_INLINE bool +add_overflows(uint64_t a, uint64_t b, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_add_overflow(a, b, out); +#else + unsigned long long tmp; +#if defined(_M_IX86) || defined(_M_X64) + auto carry = _addcarry_u64(0, a, b, &tmp); +#else + tmp = a + b; + unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); + auto carry = vector >> 63; +#endif + *out = tmp; + return carry; +#endif +} + +STANDALONE_ALWAYS_INLINE bool +mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + *out = a * b; + // This test isnt exact, but avoids doing integer division + return ( + (standalone::c10::llvm::countLeadingZeros(a) + + standalone::c10::llvm::countLeadingZeros(b)) < 64); +#endif +} + +STANDALONE_ALWAYS_INLINE bool +mul_overflows(int64_t a, int64_t b, int64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + volatile int64_t tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); +#endif +} + +template +bool safe_multiplies_u64(It first, It last, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + uint64_t prod = 1; + bool overflow = false; + for (; first != last; ++first) { + overflow |= standalone::c10::mul_overflows(prod, *first, &prod); + } + *out = prod; + return overflow; +#else + uint64_t prod = 1; + uint64_t prod_log2 = 0; + bool is_zero = false; + for (; first != last; ++first) { + auto x = static_cast(*first); + prod *= x; + // log2(0) isn't valid, so need to track it specially + is_zero |= (x == 0); + prod_log2 += standalone::c10::llvm::Log2_64_Ceil(x); + } + *out = prod; + // This test isnt exact, but avoids doing integer division + return !is_zero && (prod_log2 >= 64); +#endif +} + +template +bool safe_multiplies_u64(const Container& c, uint64_t* out) { + return safe_multiplies_u64(c.begin(), c.end(), out); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h new file mode 100644 index 00000000000..69ac4fec65f --- /dev/null +++ b/backends/aoti/slim/core/SlimTensor.h @@ -0,0 +1,637 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::slim { + +class SlimTensor { + public: + SlimTensor( + Storage&& storage, + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + int64_t storage_offset = 0) + : storage_(std::move(storage)), + storage_offset_(storage_offset), + dtype_(dtype) { + set_sizes_and_strides(sizes, strides); + } + + // Default constructor - creates an undefined tensor + SlimTensor() + : storage_(Storage()), + storage_offset_(0), + numel_(0), + dtype_(standalone::c10::ScalarType::Float), + is_contiguous_(true) { + sizes_and_strides_.set_sizes({0}); + sizes_and_strides_.set_strides({1}); + } + + SlimTensor(const SlimTensor&) = default; + SlimTensor& operator=(const SlimTensor&) = default; + SlimTensor(SlimTensor&&) = default; + SlimTensor& operator=(SlimTensor&&) = default; + + ~SlimTensor() = default; + + void reset() { + // Decrement the refcount of the storage + storage_.reset(); + } + + // Accessors + Storage storage() const { + return storage_; + } + + size_t nbytes() const { + return numel() * itemsize(); + } + + size_t itemsize() const { + return standalone::c10::elementSize(dtype_); + } + + standalone::c10::IntArrayRef sizes() const { + return sizes_and_strides_.sizes_arrayref(); + } + + int64_t size(int64_t dim) const { + int64_t wrapped_dim = + standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + return sizes_and_strides_.size_at(static_cast(wrapped_dim)); + } + + standalone::c10::IntArrayRef strides() const { + return sizes_and_strides_.strides_arrayref(); + } + + int64_t stride(int64_t dim) const { + int64_t wrapped_dim = + standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + return sizes_and_strides_.stride_at(static_cast(wrapped_dim)); + } + + standalone::c10::ScalarType dtype() const { + return dtype_; + } + + const standalone::c10::Device& device() const { + return storage_->device(); + } + + standalone::c10::DeviceType device_type() const { + return storage_->device().type(); + } + + standalone::c10::DeviceIndex device_index() const { + return storage_->device().index(); + } + + int64_t storage_offset() const { + return storage_offset_; + } + + size_t numel() const { + return numel_; + } + + size_t dim() const { + return sizes_and_strides_.size(); + } + + void* data_ptr() const { + return static_cast(storage_->data()) + storage_offset_ * itemsize(); + } + + bool is_contiguous() const { + return is_contiguous_; + } + + bool is_empty() const { + return numel_ == 0; + } + + bool is_cuda() const { + return device().is_cuda(); + } + + bool is_cpu() const { + return device().is_cpu(); + } + + // Check if tensor is defined (not default-constructed) + bool defined() const { + return storage_.get() != nullptr; + } + + // Setters + void set_storage(Storage&& new_storage) { + storage_ = std::move(new_storage); + } + + void set_sizes_and_strides( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + std::optional storage_offset = std::nullopt) { + const int64_t new_dim = static_cast(sizes.size()); + STANDALONE_CHECK( + new_dim == static_cast(strides.size()), + "dimensionality of sizes (", + new_dim, + ") must match dimensionality of strides (", + strides.size(), + ")"); + + std::vector new_sizes = sizes.vec(); + std::vector new_strides = strides.vec(); + + // stride calculation logic + bool overflowed = false; + if (new_dim > 0) { + for (int64_t dim = new_dim - 1; dim >= 0; dim--) { + if (strides[dim] >= 0) { + new_strides[dim] = strides[dim]; + } else { + // for negative strides + if (dim == new_dim - 1) { + new_strides[dim] = 1; + } else { + overflowed |= standalone::c10::mul_overflows( + new_strides[dim + 1], + std::max(new_sizes[dim + 1], 1), + &new_strides[dim]); + } + } + } + } + STANDALONE_CHECK(!overflowed, "Stride calculation overflowed"); + + sizes_and_strides_.set_sizes(new_sizes); + sizes_and_strides_.set_strides(new_strides); + if (storage_offset.has_value()) { + storage_offset_ = *storage_offset; + } + + refresh_numel(); + refresh_contiguous(); + } + + void set_sizes_contiguous(standalone::c10::IntArrayRef new_size) { + sizes_and_strides_.set_sizes(new_size); + refresh_numel(); + empty_tensor_restride(standalone::c10::MemoryFormat::Contiguous); + } + + void empty_tensor_restride(standalone::c10::MemoryFormat memory_format); + + SlimTensor resize_( + standalone::c10::IntArrayRef sizes, + std::optional optional_memory_format); + + // Conversion operations + SlimTensor to(const standalone::c10::Device& device) const { + if (device == storage_->device()) { + return *this; + } + // Does not mutate the current tensor. Returns a new tensor + Storage new_storage(new MaybeOwningStorage(storage_->clone(device))); + return SlimTensor( + std::move(new_storage), + sizes_and_strides_.sizes_arrayref(), + sizes_and_strides_.strides_arrayref(), + dtype_, + storage_offset_); + } + + SlimTensor cpu() const { + return to(CPU_DEVICE); + } + + SlimTensor cuda() const { + return to(DEFAULT_CUDA_DEVICE); + } + + SlimTensor to(standalone::c10::ScalarType dtype) const { + STANDALONE_CHECK(false, "TBD: to(dtype)"); + } + + SlimTensor& copy_(const SlimTensor& other) { + STANDALONE_CHECK( + this->numel() == other.numel(), "copy_: numel of tensors must match"); + STANDALONE_CHECK(this->dtype() == other.dtype(), "copy_: dtype must match"); + + if (this->numel() == 0) { + return *this; + } + + // Case 1: Both tensors are contiguous. We can do a fast bulk copy. + if (this->is_contiguous() && other.is_contiguous()) { + storage_->copy_( + this->data_ptr(), other.data_ptr(), other.nbytes(), other.device()); + return *this; + } + + // Case 2: At least one tensor is non-contiguous, perform element-wise copy + // that respects both source and destination strides. + const size_t elem_size = standalone::c10::elementSize(dtype_); + char* dst_data = static_cast(this->data_ptr()); + const char* src_data = static_cast(other.data_ptr()); + + std::vector counter(this->dim(), 0); + for (size_t i = 0; i < this->numel(); i++) { + // Compute src offset in elements + int64_t src_offset = 0; + for (size_t d = 0; d < other.dim(); d++) { + src_offset += counter[d] * other.stride(d); + } + + // Compute dst offset in elements + int64_t dst_offset = 0; + for (size_t d = 0; d < this->dim(); d++) { + dst_offset += counter[d] * this->stride(d); + } + + // Copy elem_size bytes from src to dst + if (this->device().is_cpu() && other.device().is_cpu()) { + std::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size); + } else if (this->device().is_cuda() || other.device().is_cuda()) { +#if defined(USE_CUDA) + DeviceTraits::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size, + device(), // dst device + other.device() // src device + ); +#else + STANDALONE_CHECK(false, "copy_: no CUDA support"); +#endif + } + // Increment the multi-dimensional counter + for (int64_t d = static_cast(this->dim()) - 1; d >= 0; --d) { + counter[d]++; + if (counter[d] < this->size(d)) { + break; + } + counter[d] = 0; + } + } + return *this; + } + + SlimTensor& fill_(const c10::Scalar& value) { + // Fast path for byte patterns on contiguous tensors - use memset + if (value.equal(0) && this->is_contiguous()) { + if (this->device().is_cpu()) { + std::memset(this->data_ptr(), 0, this->nbytes()); + return *this; + } else if (this->device().is_cuda()) { +#ifdef USE_CUDA + cudaError_t err = cudaMemset(this->data_ptr(), 0, this->nbytes()); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memset failed: ", + cudaGetErrorString(err)); + return *this; +#else + STANDALONE_CHECK(false, "CUDA support not available"); +#endif + } + } + + // Fallback to type-specific fill implementation + auto fill_value = [&](auto typed_value) { + using SType = decltype(typed_value); + if (this->device().is_cuda()) { +#ifdef USE_CUDA + if (this->is_contiguous()) { + // Fast path for contiguous tensors + if constexpr (std::is_same_v) { + // Special handling for bool since std::vector doesn't have + // data() + std::vector host_data(this->numel(), typed_value ? 1 : 0); + cudaError_t err = cudaMemcpy( + this->data_ptr(), + host_data.data(), + this->nbytes(), + cudaMemcpyHostToDevice); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memcpy failed: ", + cudaGetErrorString(err)); + } else { + std::vector host_data(this->numel(), typed_value); + cudaError_t err = cudaMemcpy( + this->data_ptr(), + host_data.data(), + this->nbytes(), + cudaMemcpyHostToDevice); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memcpy failed: ", + cudaGetErrorString(err)); + } + } else { + // Handle non-contiguous tensors by copying to CPU, filling, then + // copying back + SlimTensor cpu_tensor = this->to(CPU_DEVICE); + cpu_tensor.fill_(typed_value); + this->copy_(cpu_tensor); + } +#else + STANDALONE_CHECK(false, "CUDA support not available"); +#endif + } else if (this->device().is_cpu()) { + if (this->is_contiguous()) { + // Fast path for contiguous tensors + SType* data = static_cast(this->data_ptr()); + for (size_t i = 0; i < this->numel(); ++i) { + data[i] = typed_value; + } + } else { + // Handle non-contiguous tensors by respecting strides + const size_t elem_size = standalone::c10::elementSize(this->dtype_); + char* base_data = static_cast(this->data_ptr()); + + std::vector counter(this->dim(), 0); + for (size_t i = 0; i < this->numel(); ++i) { + // Compute offset in elements based on strides + int64_t offset = 0; + for (size_t d = 0; d < this->dim(); d++) { + offset += counter[d] * this->stride(d); + } + + // Set the value at the computed offset + SType* element_ptr = + reinterpret_cast(base_data + offset * elem_size); + *element_ptr = typed_value; + + // Increment the multi-dimensional counter + for (int64_t d = static_cast(this->dim()) - 1; d >= 0; + --d) { + counter[d]++; + if (counter[d] < this->size(d)) { + break; + } + counter[d] = 0; + } + } + } + } + }; + + switch (this->dtype()) { + case standalone::c10::ScalarType::Double: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Float: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Half: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::BFloat16: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Long: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Int: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Short: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Char: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Byte: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Bool: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::ComplexFloat: + fill_value(value.to>()); + break; + case standalone::c10::ScalarType::ComplexDouble: + fill_value(value.to>()); + break; + default: + STANDALONE_CHECK(false, "fill_: Unsupported dtype"); + } + return *this; + } + + SlimTensor clone() const { + return _clone_impl( + this->sizes(), this->strides(), this->dtype(), this->device()); + } + + SlimTensor clone_contiguous() const { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(this->sizes()); + return _clone_impl( + this->sizes(), contig_strides, this->dtype(), this->device()); + } + + // View operations + SlimTensor as_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) const; + SlimTensor as_strided_( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset); + + SlimTensor permute(standalone::c10::IntArrayRef dims) const; + + // Transpose operations + SlimTensor transpose() const; + SlimTensor transpose(int64_t dim0, int64_t dim1) const; + SlimTensor t() const; + + SlimTensor reshape(standalone::c10::IntArrayRef proposed_shape) const; + + SlimTensor narrow(int64_t dim, int64_t start, int64_t length) const; + + // Generic element access returning SlimTensor + SlimTensor operator[](standalone::c10::IntArrayRef indices) const { + STANDALONE_CHECK( + indices.size() <= this->dim(), + "Number of indices (", + indices.size(), + ") cannot exceed tensor dimensions (", + this->dim(), + ")"); + + if (indices.size() == this->dim()) { + // Full indexing - return 0-dimensional tensor + int64_t linear_index = 0; + for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = indices[i]; + int64_t size = this->size(i); + idx = standalone::c10::maybe_wrap_dim(idx, size); + linear_index += idx * this->stride(i); + } + // Create 0-dimensional tensor pointing to the indexed element + int64_t new_storage_offset = this->storage_offset_ + linear_index; + return SlimTensor( + Storage(this->storage_), {}, {}, this->dtype_, new_storage_offset); + } else { + // Partial indexing - return tensor with reduced dimensions + std::vector new_sizes; + std::vector new_strides; + int64_t offset_adjustment = 0; + + // Calculate offset from the provided indices + for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = indices[i]; + int64_t size = this->size(i); + idx = standalone::c10::maybe_wrap_dim(idx, size); + offset_adjustment += idx * this->stride(i); + } + + // Copy remaining dimensions + for (size_t i = indices.size(); i < this->dim(); ++i) { + new_sizes.push_back(this->size(i)); + new_strides.push_back(this->stride(i)); + } + + int64_t new_storage_offset = this->storage_offset_ + offset_adjustment; + return SlimTensor( + Storage(this->storage_), + new_sizes, + new_strides, + this->dtype_, + new_storage_offset); + } + } + + // Convenience overload for single index + SlimTensor operator[](int64_t index) const { + return (*this)[standalone::c10::IntArrayRef{index}]; + } + + // Convenience overloads for common multi-dimensional cases + SlimTensor operator[](std::initializer_list indices) const { + return (*this)[standalone::c10::IntArrayRef(indices)]; + } + + // Extract scalar value from 0-dimensional tensor + standalone::c10::Scalar item() const { + switch (this->dtype()) { + case standalone::c10::ScalarType::Double: + return this->item(); + case standalone::c10::ScalarType::Float: + return this->item(); + case standalone::c10::ScalarType::Half: + return this->item(); + case standalone::c10::ScalarType::BFloat16: + return this->item(); + case standalone::c10::ScalarType::Long: + return this->item(); + case standalone::c10::ScalarType::Int: + return this->item(); + case standalone::c10::ScalarType::Short: + return this->item(); + case standalone::c10::ScalarType::Char: + return this->item(); + case standalone::c10::ScalarType::Byte: + return this->item(); + case standalone::c10::ScalarType::Bool: + return this->item(); + case standalone::c10::ScalarType::ComplexFloat: + return this->item>(); + case standalone::c10::ScalarType::ComplexDouble: + return this->item>(); + default: + STANDALONE_CHECK(false, "item(): Unsupported dtype"); + } + } + + // Templated version to access 0-dimensional tensor + template + T item() const { + STANDALONE_CHECK( + this->dim() == 0, "item() can only be called on 0-dimensional tensors"); + STANDALONE_CHECK( + this->numel() == 1, "item() requires tensor to have exactly 1 element"); + + // For 0-dimensional tensors, directly access the single element at + // data_ptr() No need to compute linear index since there's only one element + const T* data = static_cast(this->data_ptr()); + return *data; + } + + private: + SlimTensor _clone_impl( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device) const { + Storage storage = new_storage(sizes, strides, dtype, device); + SlimTensor result = + SlimTensor(std::move(storage), sizes, strides, dtype, 0); + result.copy_(*this); + return result; + } + + void refresh_numel() { + numel_ = compute_numel(sizes_and_strides_.sizes_arrayref()); + } + + bool compute_is_contiguous() const { + return standalone::c10::_compute_contiguous( + sizes_and_strides_.sizes_arrayref(), + sizes_and_strides_.strides_arrayref(), + numel_); + } + + void refresh_contiguous() { + // In SlimTensor, we only care about the single is_contiguous_ flag. + // (because TensorImpl (aten) implementation has other stuff) + is_contiguous_ = compute_is_contiguous(); + } + + Storage storage_; // device_type_ and device_index_ are stored in storage_ + int64_t storage_offset_{0}; + standalone::c10::SizesAndStrides sizes_and_strides_; + // If sizes and strides are empty, the numel is 1!! However, most of the + // time, we will immediately set sizes to {0} and reset numel to 0. + // (Can't do that in the default initializers, because there's no way to + // spell "allocate a one-element array" for strides_). + size_t numel_{1}; + standalone::c10::ScalarType dtype_; + bool is_contiguous_{true}; + // NOLINTNEXTLINE(clang-diagnostic-unused-private-field) + std::array reserved_{0}; // padding to align to 8 bytes +}; + +} // namespace standalone::slim + +#include +#include diff --git a/backends/aoti/slim/core/SlimTensorResize-incl.h b/backends/aoti/slim/core/SlimTensorResize-incl.h new file mode 100644 index 00000000000..e9de9f5e0a6 --- /dev/null +++ b/backends/aoti/slim/core/SlimTensorResize-incl.h @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include +#include + +namespace standalone::slim { +inline void SlimTensor::empty_tensor_restride( + standalone::c10::MemoryFormat memory_format) { +#ifdef DEBUG + STANDALONE_INTERNAL_ASSERT( + compute_numel() == numel_, + "If you are seeing this error, that means empty_tensor_restride was " + "called before setting correct numel"); +#endif + switch (memory_format) { + case standalone::c10::MemoryFormat::Contiguous: { + // dim_ is a virtual call, don't repeat it + const auto dim_ = dim(); + sizes_and_strides_.resize(dim_); + if (dim_ > 0) { + bool overflowed = false; + const auto last_idx = dim_ - 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; + for (int64_t i = static_cast(last_idx) - 1; i >= 0; --i) { + overflowed |= standalone::c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(i + 1), + std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(i))); + } + STANDALONE_CHECK(!overflowed, "Stride calculation overflowed"); + } + break; + } + case standalone::c10::MemoryFormat::ChannelsLast: { + STANDALONE_CHECK( + dim() == 4, "required rank 4 tensor to use channels_last format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes())); + break; + } + case standalone::c10::MemoryFormat::ChannelsLast3d: { + STANDALONE_CHECK( + dim() == 5, "required rank 5 tensor to use channels_last_3d format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes())); + break; + } + case standalone::c10::MemoryFormat::Preserve: + STANDALONE_CHECK(false, "unsupported memory format ", memory_format); + // Cleaning warning messages, no need to break as STANDALONE_CHECK(false) + // terminates flow. + // break; + case standalone::c10::MemoryFormat::NumOptions: + STANDALONE_INTERNAL_ASSERT( + false, "invalid memory format ", memory_format); + } + // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually + // exclusive see #24090 + refresh_contiguous(); +} + +inline void _resize_bytes( + MaybeOwningStorage* storage, + size_t new_size_bytes, + size_t storage_offset_in_bytes) { + STANDALONE_CHECK( + storage->is_resizable(), + "Trying to resize storage that is not resizable"); + + void* new_data = nullptr; + const c10::Device& device = storage->device(); + if (new_size_bytes > 0) { + if (device.is_cpu()) { + new_data = + DeviceTraits::allocate(new_size_bytes, device); + } else if (device.is_cuda()) { + new_data = + DeviceTraits::allocate(new_size_bytes, device); + } + } + + void* old_data = storage->data(); + const size_t old_capacity = storage->nbytes(); + const size_t copy_capacity = std::min(new_size_bytes, old_capacity); + if (old_data != nullptr && copy_capacity > 0) { + if (device.is_cpu()) { + DeviceTraits::memcpy( + static_cast(new_data) + storage_offset_in_bytes, + static_cast(old_data) + storage_offset_in_bytes, + copy_capacity, + device, + device); + } else if (device.is_cuda()) { + DeviceTraits::memcpy( + static_cast(new_data) + storage_offset_in_bytes, + static_cast(old_data) + storage_offset_in_bytes, + copy_capacity, + device, + device); + } + } + + storage->free_data(); + storage->set_data_ptr_noswap(new_data); + storage->set_nbytes(new_size_bytes); +} + +inline void _maybe_resize_storage(SlimTensor* self, int64_t new_size_bytes) { + if (self->numel() == 0) { + return; + } + + const Storage& storage = self->storage(); + if (!storage) { + Storage new_storage(new MaybeOwningStorage(self->device(), new_size_bytes)); + self->set_storage(std::move(new_storage)); + } else if (new_size_bytes > static_cast(self->nbytes())) { + _resize_bytes( + storage.get(), + new_size_bytes, + self->storage_offset() * self->itemsize()); + } +} + +inline SlimTensor* _resize_impl_( + SlimTensor* self, + standalone::c10::IntArrayRef sizes, + std::optional strides, + bool resize_storage) { + if (self->sizes() == sizes && + (!strides || self->strides() == strides.value())) { + return self; + } + + const auto itemsize = self->itemsize(); + const auto storage_offset = self->storage_offset(); + int64_t storage_size = 1; + if (strides) { + self->set_sizes_and_strides(sizes, *strides); + storage_size = + compute_storage_nbytes(sizes, *strides, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(sizes); + storage_size = + compute_storage_nbytes_contiguous(sizes, itemsize, storage_offset); + } + + if (resize_storage) { + _maybe_resize_storage(self, storage_size); + } + + return self; +} + +inline SlimTensor SlimTensor::resize_( + standalone::c10::IntArrayRef sizes, + std::optional optional_memory_format) { + _resize_impl_(this, sizes, /*stride=*/std::nullopt, true); + + if (optional_memory_format.has_value()) { + standalone::c10::MemoryFormat memory_format = + static_cast( + optional_memory_format.value()); + STANDALONE_CHECK( + memory_format != standalone::c10::MemoryFormat::Preserve, + "Unsupported memory format", + memory_format); + this->empty_tensor_restride(memory_format); + } + return *this; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/core/SlimTensorView-incl.h b/backends/aoti/slim/core/SlimTensorView-incl.h new file mode 100644 index 00000000000..0df4c4705f1 --- /dev/null +++ b/backends/aoti/slim/core/SlimTensorView-incl.h @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include +#include +#include + +namespace standalone::slim { +inline SlimTensor SlimTensor::as_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) const { + SlimTensor result = *this; + result.as_strided_(sizes, strides, storage_offset); + return result; +} + +inline SlimTensor SlimTensor::as_strided_( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) { + STANDALONE_CHECK( + sizes.size() == strides.size(), + "as_strided: number of sizes (", + sizes.size(), + ") must equal number of strides (", + strides.size(), + ")"); + for (size_t i = 0; i < sizes.size(); ++i) { + STANDALONE_CHECK( + sizes[i] >= 0, + "as_strided: size at dimension ", + i, + " is negative: ", + sizes[i]); + } + STANDALONE_CHECK( + storage_offset >= 0, + "as_strided: storage_offset must be non-negative, got: ", + storage_offset); + + this->set_sizes_and_strides(sizes, strides, storage_offset); + return *this; +} + +inline SlimTensor SlimTensor::permute(standalone::c10::IntArrayRef dims) const { + const size_t ndim = this->dim(); + STANDALONE_CHECK( + ndim == static_cast(dims.size()), + "permute: dims length must be equal to tensor.dim()") + + standalone::c10::ArrayRef old_sizes = this->sizes(); + standalone::c10::ArrayRef old_strides = this->strides(); + std::vector new_sizes = old_sizes.vec(); + std::vector new_strides = old_strides.vec(); + std::vector seen_dims(ndim, false); + + for (size_t i = 0; i < ndim; i++) { + int64_t d = standalone::c10::maybe_wrap_dim(dims[i], ndim); + STANDALONE_CHECK(!seen_dims[d], "permute: duplicate dims are not allowed"); + seen_dims[d] = true; + new_sizes[i] = old_sizes[d]; + new_strides[i] = old_strides[d]; + } + + SlimTensor result = *this; + result.as_strided_(new_sizes, new_strides, this->storage_offset()); + return result; +} + +inline SlimTensor SlimTensor::transpose() const { + STANDALONE_CHECK(dim() == 2, "transpose() can only be called on 2D tensors"); + return permute({1, 0}); +} + +inline SlimTensor SlimTensor::transpose(int64_t dim0, int64_t dim1) const { + const size_t ndim = this->dim(); + std::vector dims; + for (size_t i = 0; i < ndim; i++) { + dims.push_back(static_cast(i)); + } + + // Wrap dimensions and swap them + dim0 = standalone::c10::maybe_wrap_dim(dim0, ndim); + dim1 = standalone::c10::maybe_wrap_dim(dim1, ndim); + std::swap(dims[dim0], dims[dim1]); + + return permute(dims); +} + +inline SlimTensor SlimTensor::t() const { + return transpose(); +} + +inline SlimTensor SlimTensor::reshape( + standalone::c10::IntArrayRef proposed_shape) const { + std::vector final_shape_vec = + infer_size(proposed_shape, this->numel()); + + // `compute_stride` return the proper strides to use if this + // `reshape` can be just a view. + std::optional> new_strides_opt = + compute_stride(this->sizes(), this->strides(), final_shape_vec); + + // create a view if possible + if (new_strides_opt.has_value()) { + SlimTensor result = *this; + result.as_strided_( + final_shape_vec, new_strides_opt.value(), this->storage_offset()); + return result; + } + + // if a view is not possible, create a contiguous clone and reshape that + SlimTensor contiguous_clone = this->clone_contiguous(); + // after cloning, the tensor is already contiguous. We just need to update + // its metadata to reflect the new shape. This is effectively a view of + // the new contiguous clone + contiguous_clone.set_sizes_contiguous(final_shape_vec); + return contiguous_clone; +} + +inline SlimTensor SlimTensor::narrow(int64_t dim, int64_t start, int64_t length) + const { + STANDALONE_CHECK( + this->dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + dim = standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + start = standalone::c10::maybe_wrap_dim( + start, static_cast(this->size(dim))); + + STANDALONE_CHECK(length >= 0, "narrow(): length must be non-negative."); + int64_t end = start + length; + STANDALONE_CHECK( + end <= this->size(dim), + "Invalid range to narrow. range(", + start, + ", ", + start + length, + ") must be a subset of range(0, ", + this->size(dim), + ")."); + + SlimTensor result = *this; + int64_t new_storage_offset = + this->storage_offset() + start * this->stride(dim); + std::vector new_sizes = this->sizes().vec(); + new_sizes[dim] = length; + result.as_strided_(new_sizes, this->strides(), new_storage_offset); + return result; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h new file mode 100644 index 00000000000..4230a0d2b0a --- /dev/null +++ b/backends/aoti/slim/core/Storage.h @@ -0,0 +1,307 @@ +#pragma once +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace standalone::slim { +using DeleterFn = void (*)(void*); + +namespace detail { +inline void noop(void*) {} +} // namespace detail + +const standalone::c10::Device CPU_DEVICE = + standalone::c10::Device(standalone::c10::DeviceType::CPU, 0); + +const standalone::c10::Device DEFAULT_CUDA_DEVICE = + standalone::c10::Device(standalone::c10::DeviceType::CUDA, 0); + +// standalone::c10::Device traits template for device-specific operations +template +struct DeviceTraits; + +// CPU specialization +template <> +struct DeviceTraits { + static void* allocate( + size_t nbytes, + const standalone::c10::Device& device = CPU_DEVICE) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + return malloc(nbytes); + } + + static void free(void* ptr) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + std::free(ptr); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& dst_device, + const standalone::c10::Device& src_device) { + std::memcpy(dst, src, nbytes); + } +}; + +// CUDA specialization +#ifdef USE_CUDA +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const standalone::c10::Device& device) { + standalone::slim::cuda::CUDAGuard guard(device); + void* data = nullptr; + STANDALONE_CUDA_CHECK(cudaMalloc(&data, nbytes)); + return data; + } + + static void free(void* ptr) { + STANDALONE_CUDA_CHECK_WARN(cudaFree(ptr)); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& dst_device, + const standalone::c10::Device& src_device) { + // Determine the direction + cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; + standalone::c10::Device cuda_device = + dst_device; // Default to destination device + + if (src_device.is_cpu()) { + direction = cudaMemcpyHostToDevice; + } else if (dst_device.is_cpu()) { + direction = cudaMemcpyDeviceToHost; + cuda_device = src_device; // Use source CUDA device + } else { + STANDALONE_CHECK( + src_device.index() == dst_device.index(), + "CUDA memcpy failed across different device indices: ", + src_device.index(), + "!=", + dst_device.index()); + } + // Set up CUDA context for the appropriate device + standalone::slim::cuda::CUDAGuard guard(cuda_device); + STANDALONE_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction)); + } +}; +#else +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const standalone::c10::Device& device) { + STANDALONE_CHECK(false, "Build with USE_CUDA=1 to enable CUDA support"); + } + + static void free(void* ptr) { + STANDALONE_WARN("Build with USE_CUDA=1 to enable CUDA support"); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& dst_device, + const standalone::c10::Device& src_device) { + STANDALONE_CHECK(false, "Build with USE_CUDA=1 to enable CUDA support"); + } +}; +#endif + +// Storage can be either owning or non-owning. For AOTI-generated intermediate +// tensors, the storage is always owning. For constant tensors, the storage is +// non-owning. +class MaybeOwningStorage { + public: + MaybeOwningStorage(const standalone::c10::Device& device, size_t nbytes) + : device_(device), capacity_(nbytes), is_owning_(true) { + // Allocating memory here so owning_ has to be true. + if (device.is_cpu()) { + data_ = DeviceTraits::allocate( + nbytes, device); + deleter_ = DeviceTraits::free; + } else if (device.is_cuda()) { + data_ = DeviceTraits::allocate( + nbytes, device); + deleter_ = DeviceTraits::free; + } else { + STANDALONE_CHECK(false, "Unsupported device type"); + } + } + + MaybeOwningStorage( + const standalone::c10::Device& device, + void* data, + size_t nbytes) + : device_(device), data_(data), capacity_(nbytes), is_owning_(false) { + // data pointer is not owned by this object + } + + MaybeOwningStorage() = delete; + MaybeOwningStorage& operator=(const MaybeOwningStorage&) = delete; + MaybeOwningStorage(const MaybeOwningStorage&) = delete; + + // Move constructor + MaybeOwningStorage(MaybeOwningStorage&& other) noexcept + : device_(other.device_), + data_(other.data_), + capacity_(other.capacity_), + deleter_(other.deleter_), + is_owning_(other.is_owning_) { + // Leave the moved-from object in a safe state + other.data_ = nullptr; + other.capacity_ = 0; + other.deleter_ = detail::noop; + other.is_owning_ = false; + } + + // Move assignment operator + MaybeOwningStorage& operator=(MaybeOwningStorage&& other) noexcept { + if (this != &other) { + // Free current resources + free_data(); + + // Transfer ownership from other + device_ = other.device_; + data_ = other.data_; + capacity_ = other.capacity_; + deleter_ = other.deleter_; + is_owning_ = other.is_owning_; + + // Leave the moved-from object in a safe state + other.data_ = nullptr; + other.capacity_ = 0; + other.deleter_ = detail::noop; + other.is_owning_ = false; + } + return *this; + } + + ~MaybeOwningStorage() { + free_data(); + } + + void copy_( + void* dst_data_ptr, + void* src_data_ptr, + size_t nbytes, + const standalone::c10::Device& src_device) { + STANDALONE_CHECK( + dst_data_ptr, "Storage clone failed: dst_data_ptr can not be nullptr") + STANDALONE_CHECK( + src_data_ptr, "Storage clone failed: src_data_ptr can not be nullptr") + if (dst_data_ptr == src_data_ptr) { + return; + } + + if (device_.is_cpu() && src_device.is_cpu()) { + // CPU to CPU copy + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } else { + // At least one of the devices is CUDA + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } + } + + MaybeOwningStorage clone(const standalone::c10::Device& device) const { + STANDALONE_CHECK( + data_, "Storage clone failed: source data can not be nullptr") + // Create a new owning storage with the specified device and same capacity + MaybeOwningStorage cloned_storage(device, capacity_); + + // Copy the data from the current storage to the new storage + if (device_.is_cpu() && device.is_cpu()) { + // CPU to CPU copy + DeviceTraits::memcpy( + cloned_storage.data_, data_, capacity_, device, device_); + } else { + // At least one of the devices is CUDA + DeviceTraits::memcpy( + cloned_storage.data_, data_, capacity_, device, device_); + } + + return cloned_storage; + } + + void* data() const { + // Always return nullptr for zero-sized storage + if (capacity_ == 0) { + return nullptr; + } + return data_; + } + + const standalone::c10::Device& device() const { + return device_; + } + + size_t nbytes() const { + return this->capacity_; + } + + void unsafe_set_to_non_owning() { + // This is only used when interacting with at::Tensor. When testing + // standalone AOTI from pytorch, we need to convert the output SlimTensor + // into at::Tensor, which means the storage ownership should be stolen by + // at::Tensor. When all the SlimTensors referencing the storage are + // destroyed, the storage should NOT be freed. + deleter_ = detail::noop; + is_owning_ = false; + } + + bool is_resizable() const { + return is_owning_; + } + + void free_data() { + if (data_ != nullptr) { + deleter_(data_); + } + } + + void set_data_ptr_noswap(void* new_data) { + data_ = new_data; + } + + void set_nbytes(size_t new_nbytes) { + capacity_ = new_nbytes; + } + + private: + standalone::c10::Device device_ = CPU_DEVICE; + void* data_ = nullptr; + size_t capacity_ = 0; + DeleterFn deleter_ = detail::noop; + bool is_owning_ = false; +}; + +using Storage = SharedPtr; + +inline Storage new_storage( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + size_t nbytes = compute_storage_nbytes( + sizes, strides, standalone::c10::elementSize(dtype), 0); + return Storage(new MaybeOwningStorage(device, nbytes)); +} +} // namespace standalone::slim diff --git a/backends/aoti/slim/cuda/Exception.h b/backends/aoti/slim/cuda/Exception.h new file mode 100644 index 00000000000..d777352c1d7 --- /dev/null +++ b/backends/aoti/slim/cuda/Exception.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef USE_CUDA + +#include +#include +#include + +#include +#include +#include + +#include +#include + +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + ET_CHECK_MSG(__err == cudaSuccess, "%s", cudaGetErrorString(__err)); \ + } while (0) + +#define ET_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (ET_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + ET_LOG(Warning, "CUDA warning: %s", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif // USE_CUDA diff --git a/backends/aoti/slim/cuda/Guard.h b/backends/aoti/slim/cuda/Guard.h new file mode 100644 index 00000000000..c9b2441b148 --- /dev/null +++ b/backends/aoti/slim/cuda/Guard.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace standalone::slim::cuda { + +// Thread-local stream management +namespace detail { +inline thread_local std:: + unordered_map + current_streams_; +} + +/// Set the current CUDA stream for the specified device +inline void setCurrentCUDAStream( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index = -1) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + STANDALONE_CUDA_CHECK(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + detail::current_streams_[device_index] = stream; +} + +/// Get the current CUDA stream for the specified device +inline cudaStream_t getCurrentCUDAStream( + standalone::c10::DeviceIndex device_index = -1) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + STANDALONE_CUDA_CHECK(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + auto it = detail::current_streams_.find(device_index); + if (it != detail::current_streams_.end()) { + return it->second; + } + + // Create a new stream and set it as current + cudaStream_t stream; + STANDALONE_CUDA_CHECK(cudaStreamCreate(&stream)); + setCurrentCUDAStream(stream, device_index); + return stream; +} + +struct CUDAGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAGuard() = delete; + + /// Set the current CUDA device to the passed device index. + explicit CUDAGuard(standalone::c10::DeviceIndex device_index) { + set_index(device_index); + } + + /// Sets the current CUDA device to the passed device. Errors if the passed + /// device is not a CUDA device. + explicit CUDAGuard(standalone::c10::Device device) { + STANDALONE_CHECK( + device.is_cuda(), + "Expected a CUDA device for CUDAGuard, but got ", + device); + set_index(device.index()); + } + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + ~CUDAGuard() { + // Restore the original device if necessary + if (original_device_index_ != current_device_index_) { + STANDALONE_CUDA_CHECK_WARN(cudaSetDevice(original_device_index_)); + } + } + + /// Sets the CUDA device to the given device index. + void set_index(standalone::c10::DeviceIndex device_index) { + int orig_index = -1; + STANDALONE_CUDA_CHECK(cudaGetDevice(&orig_index)); + + original_device_index_ = orig_index; + current_device_index_ = device_index; + if (current_device_index_ != original_device_index_) { + STANDALONE_CUDA_CHECK(cudaSetDevice(current_device_index_)); + } + } + + private: + /// The guard for the current device. + standalone::c10::DeviceIndex original_device_index_; + standalone::c10::DeviceIndex current_device_index_; +}; + +struct CUDAStreamGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAStreamGuard() = delete; + + /// Set the current CUDA stream to the passed stream on the specified device. + explicit CUDAStreamGuard( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index) + : device_guard_(device_index) { + set_stream(stream, device_index); + } + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAStreamGuard(CUDAStreamGuard&& other) = delete; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; + + ~CUDAStreamGuard() { + // Restore the original stream for the device + setCurrentCUDAStream(original_stream_, device_index_); + // Device guard will automatically restore the original device + } + + /// Sets the CUDA stream to the given stream on the specified device. + void set_stream( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index) { + // Store the original stream for this device + original_stream_ = getCurrentCUDAStream(device_index); + current_stream_ = stream; + device_index_ = device_index; + + // Set the new stream as current for this device + setCurrentCUDAStream(stream, device_index); + } + + /// Get the current guarded stream + cudaStream_t stream() const { + return current_stream_; + } + + /// Get the device index being guarded + standalone::c10::DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + standalone::c10::DeviceIndex device_index_; +}; + +} // namespace standalone::slim::cuda diff --git a/backends/aoti/slim/factory/Empty.h b/backends/aoti/slim/factory/Empty.h new file mode 100644 index 00000000000..bbd4996b84c --- /dev/null +++ b/backends/aoti/slim/factory/Empty.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace standalone::slim { +// The returned SlimTensor owns the underlying storage +inline SlimTensor empty_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + Storage storage = new_storage(sizes, strides, dtype, device); + return SlimTensor(std::move(storage), sizes, strides, dtype, 0); +} + +inline SlimTensor empty( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(sizes); + Storage storage = new_storage(sizes, contig_strides, dtype, device); + return SlimTensor(std::move(storage), sizes, contig_strides, dtype, 0); +} + +inline SlimTensor empty_like(const SlimTensor& other) { + return empty_strided( + other.sizes(), other.strides(), other.dtype(), other.device()); +} +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/Factory.h b/backends/aoti/slim/factory/Factory.h new file mode 100644 index 00000000000..5e172bc9f6a --- /dev/null +++ b/backends/aoti/slim/factory/Factory.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace standalone::slim { +inline SlimTensor zeros( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor tensor = empty(sizes, dtype, device); + tensor.fill_(standalone::c10::Scalar(0)); + return tensor; +} + +inline SlimTensor zeros_like(const SlimTensor& other) { + return zeros(other.sizes(), other.dtype(), other.device()); +} + +inline SlimTensor ones( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor tensor = empty(sizes, dtype, device); + tensor.fill_(standalone::c10::Scalar(1)); + return tensor; +} + +inline SlimTensor ones_like(const SlimTensor& other) { + return ones(other.sizes(), other.dtype(), other.device()); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/FromBlob.h b/backends/aoti/slim/factory/FromBlob.h new file mode 100644 index 00000000000..d1877f7f31d --- /dev/null +++ b/backends/aoti/slim/factory/FromBlob.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace standalone::slim { + +// The returned SlimTensor does not own the underlying storage +inline SlimTensor from_blob( + void* data, + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE, + int64_t storage_offset = 0) { + STANDALONE_CHECK(data != nullptr, "data pointer can not be nullptr"); + + Storage storage(new MaybeOwningStorage( + device, + data, + compute_storage_nbytes( + sizes, strides, elementSize(dtype), storage_offset))); + return SlimTensor(std::move(storage), sizes, strides, dtype, storage_offset); +} + +inline SlimTensor from_blob( + void* data, + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE, + int64_t storage_offset = 0) { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(sizes); + return from_blob(data, sizes, contig_strides, dtype, device, storage_offset); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/FromScalar.h b/backends/aoti/slim/factory/FromScalar.h new file mode 100644 index 00000000000..223f734d940 --- /dev/null +++ b/backends/aoti/slim/factory/FromScalar.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace standalone::slim { + +inline SlimTensor scalar_to_tensor( + const standalone::c10::Scalar& s, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor result = empty_strided({}, {}, s.type(), device); + result.fill_(s); + return result; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/Pad.h b/backends/aoti/slim/factory/Pad.h new file mode 100644 index 00000000000..4d7fef731bd --- /dev/null +++ b/backends/aoti/slim/factory/Pad.h @@ -0,0 +1,106 @@ +#pragma once + +#include + +namespace standalone::slim { + +inline SlimTensor constant_pad_nd( + const SlimTensor& self, + standalone::c10::IntArrayRef pad, + const standalone::c10::Scalar& value) { + STANDALONE_CHECK(pad.size() % 2 == 0, "Length of pad must be even"); + + standalone::c10::IntArrayRef input_sizes = self.sizes(); + int64_t l_inp = self.dim(); + int64_t l_pad = static_cast(pad.size()) / 2; + int64_t l_diff = l_inp - l_pad; + + STANDALONE_CHECK( + l_pad <= l_inp, + "Length of pad should be no more than twice the input's dimension."); + + bool all_pads_non_positive = true; + SlimTensor c_input = self; + for (int64_t i = l_diff; i < l_inp; i++) { + int64_t pad_idx = 2 * (l_inp - i - 1); + + if (pad[pad_idx] < 0) { + c_input = + c_input.narrow(i, -pad[pad_idx], c_input.size(i) + pad[pad_idx]); + } else if (pad[pad_idx] != 0) { + all_pads_non_positive = false; + } + if (pad[pad_idx + 1] < 0) { + c_input = c_input.narrow(i, 0, c_input.size(i) + pad[pad_idx + 1]); + } else if (pad[pad_idx + 1] != 0) { + all_pads_non_positive = false; + } + } + + // if none of the pads are positive we can optimize and just return the result + // of calling .narrow() on the input + if (all_pads_non_positive) { + return c_input.clone_contiguous(); + } + + // calculate the new shape for the output tensor + std::vector new_shape; + new_shape.reserve(l_diff); + for (int64_t i = 0; i < l_diff; i++) { + new_shape.emplace_back(input_sizes[i]); + } + + for (const auto i : standalone::c10::irange((size_t)l_pad)) { + auto pad_idx = pad.size() - ((i + 1) * 2); + auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; + STANDALONE_CHECK( + new_dim > 0, + "The input size ", + input_sizes[l_diff + i], + ", plus negative padding ", + pad[pad_idx], + " and ", + pad[pad_idx + 1], + " resulted in a negative output size, " + "which is invalid. Check dimension ", + l_diff + i, + " of your input."); + new_shape.emplace_back(new_dim); + } + + SlimTensor output = empty(new_shape, self.dtype(), self.device()); + output.fill_(value); + + // create a view into the center of the output tensor + SlimTensor c_output = output; + for (const auto i : standalone::c10::irange(l_diff, l_inp)) { + auto pad_idx = 2 * (l_inp - i - 1); + if (pad[pad_idx] > 0) { + c_output = + c_output.narrow(i, pad[pad_idx], c_output.size(i) - pad[pad_idx]); + } + if (pad[pad_idx + 1] > 0) { + c_output = c_output.narrow(i, 0, c_output.size(i) - pad[pad_idx + 1]); + } + } + // copy the input data into the center view + c_output.copy_(c_input); + return output; +} + +inline SlimTensor pad( + const SlimTensor& self, + standalone::c10::IntArrayRef pad, + std::string_view mode, + std::optional value) { + if (mode == "constant") { + return constant_pad_nd(self, pad, value.value_or(0.0)); + } + STANDALONE_CHECK( + false, + "Unsupported padding mode: ", + mode, + ". Only constant mode is available."); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/targets.bzl b/backends/aoti/slim/targets.bzl new file mode 100644 index 00000000000..62db9452984 --- /dev/null +++ b/backends/aoti/slim/targets.bzl @@ -0,0 +1,81 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define SlimTensor library targets. + + SlimTensor is a lightweight tensor implementation for AOTI (Ahead-of-Time Inference) + that provides a minimal, efficient tensor abstraction for ExecuTorch CUDA backend. + + This is a direct port from torchnative/standalone/slim with minimal modifications. + """ + + # Utility library (SharedPtr, SizeUtil) + runtime.cxx_library( + name = "util", + exported_headers = glob(["util/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # Core SlimTensor library (CPU only) + runtime.cxx_library( + name = "core", + exported_headers = glob(["core/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":util", + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # Factory functions library + runtime.cxx_library( + name = "factory", + exported_headers = glob(["factory/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # CUDA support library + runtime.cxx_library( + name = "cuda", + exported_headers = glob(["cuda/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_preprocessor_flags = ["-DUSE_CUDA"], + exported_deps = [ + ":core", + "//executorch/backends/aoti/slim/c10:c10", + "//executorch/backends/aoti/slim/c10:c10_cuda", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + + # CPU-only SlimTensor library (no CUDA dependencies) + runtime.cxx_library( + name = "slim_tensor_cpu", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + ":factory", + ":util", + ], + ) + + # Full SlimTensor library (with CUDA support) + runtime.cxx_library( + name = "slim_tensor", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + ":factory", + ":cuda", + ":util", + ], + ) diff --git a/backends/aoti/slim/tests/TARGETS b/backends/aoti/slim/tests/TARGETS new file mode 100644 index 00000000000..f91c46c0f20 --- /dev/null +++ b/backends/aoti/slim/tests/TARGETS @@ -0,0 +1,5 @@ +load("targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/tests/targets.bzl b/backends/aoti/slim/tests/targets.bzl new file mode 100644 index 00000000000..0f0eb843c7d --- /dev/null +++ b/backends/aoti/slim/tests/targets.bzl @@ -0,0 +1,31 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") + +def slim_tensor_cpp_unittest(name, extra_deps = []): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti/slim:slim_tensor_cpu", + ] + extra_deps, + ) + +def slim_tensor_cuda_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti/slim:slim_tensor", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Define test targets for SlimTensor library.""" + slim_tensor_cpp_unittest("slim_tensor_basic") + slim_tensor_cuda_cpp_unittest("slim_tensor_cuda") diff --git a/backends/aoti/slim/tests/test_slim_tensor_basic.cpp b/backends/aoti/slim/tests/test_slim_tensor_basic.cpp new file mode 100644 index 00000000000..37b6ccb240d --- /dev/null +++ b/backends/aoti/slim/tests/test_slim_tensor_basic.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace standalone::slim { +namespace { + +TEST(SlimTensorBasicTest, EmptyTensorCreation) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.dim(), 3); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.size(2), 4); + EXPECT_EQ(tensor.numel(), 24); + EXPECT_EQ(tensor.dtype(), standalone::c10::ScalarType::Float); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TEST(SlimTensorBasicTest, EmptyTensorContiguousStrides) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.stride(0), 12); + EXPECT_EQ(tensor.stride(1), 4); + EXPECT_EQ(tensor.stride(2), 1); +} + +TEST(SlimTensorBasicTest, ZerosTensorCreation) { + auto tensor = zeros({3, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.numel(), 9); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(data[i], 0.0f); + } +} + +TEST(SlimTensorBasicTest, OnesTensorCreation) { + auto tensor = ones({2, 2}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.numel(), 4); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(data[i], 1.0f); + } +} + +TEST(SlimTensorBasicTest, FillTensor) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(5.0f); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 5.0f); + } +} + +TEST(SlimTensorBasicTest, FromBlobNonOwning) { + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor = from_blob( + data.data(), {2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.dim(), 2); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.numel(), 6); + EXPECT_EQ(tensor.data_ptr(), data.data()); +} + +TEST(SlimTensorBasicTest, Clone) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(3.14f); + + auto cloned = tensor.clone(); + EXPECT_NE(cloned.data_ptr(), tensor.data_ptr()); + EXPECT_EQ(cloned.sizes(), tensor.sizes()); + EXPECT_EQ(cloned.strides(), tensor.strides()); + + float* cloned_data = static_cast(cloned.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(cloned_data[i], 3.14f); + } +} + +TEST(SlimTensorBasicTest, CopyFrom) { + auto src = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + src.fill_(2.5f); + + auto dst = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + dst.copy_(src); + + float* dst_data = static_cast(dst.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(dst_data[i], 2.5f); + } +} + +TEST(SlimTensorBasicTest, Reshape) { + auto tensor = empty({2, 6}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(1.0f); + + auto reshaped = tensor.reshape({3, 4}); + EXPECT_EQ(reshaped.dim(), 2); + EXPECT_EQ(reshaped.size(0), 3); + EXPECT_EQ(reshaped.size(1), 4); + EXPECT_EQ(reshaped.numel(), 12); +} + +TEST(SlimTensorBasicTest, Transpose) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto transposed = tensor.transpose(0, 1); + EXPECT_EQ(transposed.size(0), 3); + EXPECT_EQ(transposed.size(1), 2); +} + +TEST(SlimTensorBasicTest, Permute) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto permuted = tensor.permute({2, 0, 1}); + EXPECT_EQ(permuted.size(0), 4); + EXPECT_EQ(permuted.size(1), 2); + EXPECT_EQ(permuted.size(2), 3); +} + +TEST(SlimTensorBasicTest, Narrow) { + auto tensor = empty({10}, standalone::c10::ScalarType::Float, CPU_DEVICE); + for (int i = 0; i < 10; ++i) { + static_cast(tensor.data_ptr())[i] = static_cast(i); + } + + auto narrowed = tensor.narrow(0, 2, 5); + EXPECT_EQ(narrowed.dim(), 1); + EXPECT_EQ(narrowed.size(0), 5); + + float* narrowed_data = static_cast(narrowed.data_ptr()); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(narrowed_data[i], static_cast(i + 2)); + } +} + +TEST(SlimTensorBasicTest, EmptyLike) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto empty_like_tensor = empty_like(tensor); + EXPECT_EQ(empty_like_tensor.sizes(), tensor.sizes()); + EXPECT_EQ(empty_like_tensor.dtype(), tensor.dtype()); + EXPECT_EQ(empty_like_tensor.device(), tensor.device()); +} + +TEST(SlimTensorBasicTest, ZerosLike) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto zeros_tensor = zeros_like(tensor); + EXPECT_EQ(zeros_tensor.sizes(), tensor.sizes()); + + float* data = static_cast(zeros_tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 0.0f); + } +} + +} // namespace +} // namespace standalone::slim diff --git a/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp b/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp new file mode 100644 index 00000000000..571d4f99893 --- /dev/null +++ b/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include + +namespace standalone::slim { +namespace { + +class SlimTensorCUDATest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA device not available"; + } + } +}; + +TEST_F(SlimTensorCUDATest, EmptyCUDATensorCreation) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.dim(), 3); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.size(2), 4); + EXPECT_EQ(tensor.numel(), 24); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TEST_F(SlimTensorCUDATest, ZerosCUDATensor) { + auto tensor = + zeros({3, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.numel(), 9); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); + + std::vector host_data(9); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 9 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(host_data[i], 0.0f); + } +} + +TEST_F(SlimTensorCUDATest, OnesCUDATensor) { + auto tensor = + ones({2, 2}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.numel(), 4); + + std::vector host_data(4); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 4 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(host_data[i], 1.0f); + } +} + +TEST_F(SlimTensorCUDATest, FillCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + tensor.fill_(5.0f); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 5.0f); + } +} + +TEST_F(SlimTensorCUDATest, CloneCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + tensor.fill_(3.14f); + + auto cloned = tensor.clone(); + EXPECT_NE(cloned.data_ptr(), tensor.data_ptr()); + EXPECT_EQ(cloned.sizes(), tensor.sizes()); + EXPECT_EQ(cloned.device(), tensor.device()); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + cloned.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_FLOAT_EQ(host_data[i], 3.14f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCUDAToCUDA) { + auto src = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + src.fill_(2.5f); + + auto dst = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + dst.copy_(src); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + dst.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 2.5f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCPUToCUDA) { + auto cpu_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + cpu_tensor.fill_(1.5f); + + auto cuda_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_tensor); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + cuda_tensor.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 1.5f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCUDAToCPU) { + auto cuda_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.fill_(4.5f); + + auto cpu_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + cpu_tensor.copy_(cuda_tensor); + + float* data = static_cast(cpu_tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 4.5f); + } +} + +TEST_F(SlimTensorCUDATest, CUDAGuard) { + cuda::CUDAGuard guard(0); + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); +} + +TEST_F(SlimTensorCUDATest, ReshapeCUDATensor) { + auto tensor = + empty({2, 6}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto reshaped = tensor.reshape({3, 4}); + EXPECT_EQ(reshaped.dim(), 2); + EXPECT_EQ(reshaped.size(0), 3); + EXPECT_EQ(reshaped.size(1), 4); + EXPECT_EQ(reshaped.device(), tensor.device()); +} + +TEST_F(SlimTensorCUDATest, TransposeCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto transposed = tensor.transpose(0, 1); + EXPECT_EQ(transposed.size(0), 3); + EXPECT_EQ(transposed.size(1), 2); + EXPECT_EQ(transposed.device(), tensor.device()); +} + +TEST_F(SlimTensorCUDATest, PermuteCUDATensor) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto permuted = tensor.permute({2, 0, 1}); + EXPECT_EQ(permuted.size(0), 4); + EXPECT_EQ(permuted.size(1), 2); + EXPECT_EQ(permuted.size(2), 3); + EXPECT_EQ(permuted.device(), tensor.device()); +} + +} // namespace +} // namespace standalone::slim diff --git a/backends/aoti/slim/tests/test_type_convert.cpp b/backends/aoti/slim/tests/test_type_convert.cpp new file mode 100644 index 00000000000..a93c7d27d70 --- /dev/null +++ b/backends/aoti/slim/tests/test_type_convert.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace executorch::backends::aoti::slim { +namespace { + +TEST(TypeConvertTest, ToInt32Vec) { + std::vector int64_vec = {1, 2, 3, 4, 5}; + auto int32_vec = to_int32_vec(int64_vec); + + EXPECT_EQ(int32_vec.size(), 5); + EXPECT_EQ(int32_vec[0], 1); + EXPECT_EQ(int32_vec[1], 2); + EXPECT_EQ(int32_vec[2], 3); + EXPECT_EQ(int32_vec[3], 4); + EXPECT_EQ(int32_vec[4], 5); +} + +TEST(TypeConvertTest, ToInt64Vec) { + std::vector int32_vec = {10, 20, 30}; + auto int64_vec = to_int64_vec(int32_vec); + + EXPECT_EQ(int64_vec.size(), 3); + EXPECT_EQ(int64_vec[0], 10); + EXPECT_EQ(int64_vec[1], 20); + EXPECT_EQ(int64_vec[2], 30); +} + +TEST(TypeConvertTest, ToInt32VecEmpty) { + std::vector empty_vec; + auto result = to_int32_vec(empty_vec); + EXPECT_TRUE(result.empty()); +} + +TEST(TypeConvertTest, ToInt64VecEmpty) { + std::vector empty_vec; + auto result = to_int64_vec(empty_vec); + EXPECT_TRUE(result.empty()); +} + +TEST(TypeConvertTest, SafeNarrowInt64ToInt32) { + int64_t value = 42; + int32_t result = safe_narrow(value); + EXPECT_EQ(result, 42); +} + +TEST(TypeConvertTest, SafeNarrowInt32ToInt16) { + int32_t value = 1000; + int16_t result = safe_narrow(value); + EXPECT_EQ(result, 1000); +} + +TEST(TypeConvertTest, ToInt32VecLargeValues) { + std::vector int64_vec = {1000000, 2000000, 3000000}; + auto int32_vec = to_int32_vec(int64_vec); + + EXPECT_EQ(int32_vec.size(), 3); + EXPECT_EQ(int32_vec[0], 1000000); + EXPECT_EQ(int32_vec[1], 2000000); + EXPECT_EQ(int32_vec[2], 3000000); +} + +TEST(TypeConvertTest, ToInt64VecFromUint32) { + std::vector uint32_vec = {100, 200, 300}; + auto int64_vec = to_int64_vec(uint32_vec); + + EXPECT_EQ(int64_vec.size(), 3); + EXPECT_EQ(int64_vec[0], 100); + EXPECT_EQ(int64_vec[1], 200); + EXPECT_EQ(int64_vec[2], 300); +} + +} // namespace +} // namespace executorch::backends::aoti::slim diff --git a/backends/aoti/slim/util/SharedPtr.h b/backends/aoti/slim/util/SharedPtr.h new file mode 100644 index 00000000000..9ad565d9ab9 --- /dev/null +++ b/backends/aoti/slim/util/SharedPtr.h @@ -0,0 +1,222 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace standalone::slim { + +/** + * NonAtomicSharedPtr - A lightweight, non-thread-safe shared pointer + * implementation + * + * This class provides shared ownership semantics similar to std::shared_ptr but + * without atomic operations, making it faster in single-threaded contexts where + * thread safety is not required. + * + * Primary Use Cases: + * 1. Intermediate SlimTensor Storage Management: + * - Manages temporary tensors created during model execution + * - These tensors are confined to single-threaded execution contexts + * - Avoids the overhead of atomic reference counting in std::shared_ptr + * + * 2. Input/Output Tensor References: + * - Provides reference counting for input/output tensors + * - Tensor lifetimes are externally managed (not by AOTI-generated code) + * - Uses dummy deleters to prevent premature deallocation + * - Reference counting still occurs but actual cleanup is deferred + * + * Performance Benefits: + * - Non-atomic reference counting reduces CPU overhead + * - Smaller memory footprint compared to std::shared_ptr + * - Optimized for single-threaded tensor operations + * + * Thread Safety: NOT THREAD-SAFE + * - Must only be used in single-threaded contexts + * - Concurrent access will result in undefined behavior + * - Define the USE_MULTI_THREAD macro to use std::shared_ptr instead when + * thread safety is required + */ +template +class NonAtomicSharedPtr { + private: + struct ControlBlock { + int count = 1; + T* ptr; + using Deleter = void (*)(T*); + Deleter deleter; + + ControlBlock(T* p, Deleter d) : ptr(p), deleter(d) {} + ControlBlock(const ControlBlock&) = delete; + ControlBlock& operator=(const ControlBlock&) = delete; + ControlBlock(ControlBlock&&) = delete; + ControlBlock& operator=(ControlBlock&&) = delete; + + ~ControlBlock() { + if (ptr) { + deleter(ptr); + } + } + }; + + ControlBlock* cb_; + + static void default_deleter(T* p) { + delete p; + } + + void cleanup() { + if (cb_ && --cb_->count == 0) { + delete cb_; + } + cb_ = nullptr; + } + + public: + // Default constructor + NonAtomicSharedPtr() noexcept : cb_(nullptr) {} + + // Constructor from raw pointer + explicit NonAtomicSharedPtr( + T* p, + typename ControlBlock::Deleter d = default_deleter) + : cb_(p ? new ControlBlock(p, d) : nullptr) {} + + // Copy constructor + NonAtomicSharedPtr(const NonAtomicSharedPtr& other) noexcept + : cb_(other.cb_) { + if (cb_) { + ++cb_->count; + } + } + + // Move constructor + NonAtomicSharedPtr(NonAtomicSharedPtr&& other) noexcept : cb_(other.cb_) { + other.cb_ = nullptr; + } + + // Destructor + ~NonAtomicSharedPtr() { + cleanup(); + } + + // Copy assignment + NonAtomicSharedPtr& operator=(const NonAtomicSharedPtr& other) noexcept { + if (this != &other) { + cleanup(); + cb_ = other.cb_; + if (cb_) { + ++cb_->count; + } + } + return *this; + } + + // Move assignment + NonAtomicSharedPtr& operator=(NonAtomicSharedPtr&& other) noexcept { + if (this != &other) { + cleanup(); + cb_ = other.cb_; + other.cb_ = nullptr; + } + return *this; + } + + // Modifiers + void reset( + T* p = nullptr, + typename ControlBlock::Deleter d = default_deleter) { + *this = NonAtomicSharedPtr(p, d); + } + + void swap(NonAtomicSharedPtr& other) noexcept { + std::swap(cb_, other.cb_); + } + + // Observers + T* get() const noexcept { + return cb_ ? cb_->ptr : nullptr; + } + T& operator*() const { + STANDALONE_CHECK(cb_, "Dereferencing null NonAtomicSharedPtr"); + return *cb_->ptr; + } + T* operator->() const { + STANDALONE_CHECK(cb_, "Accessing member of null NonAtomicSharedPtr"); + return cb_->ptr; + } + long use_count() const noexcept { + return cb_ ? cb_->count : 0; + } + explicit operator bool() const noexcept { + return cb_ != nullptr; + } + + // Friend swap for ADL + friend void swap(NonAtomicSharedPtr& a, NonAtomicSharedPtr& b) noexcept { + a.swap(b); + } + + // Comparison operators + friend bool operator==( + const NonAtomicSharedPtr& lhs, + const NonAtomicSharedPtr& rhs) noexcept { + return lhs.get() == rhs.get(); + } + + friend bool operator!=( + const NonAtomicSharedPtr& lhs, + const NonAtomicSharedPtr& rhs) noexcept { + return !(lhs == rhs); + } + + friend bool operator==( + const NonAtomicSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.get() == nullptr; + } + + friend bool operator!=( + const NonAtomicSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.get() != nullptr; + } + + friend bool operator==( + std::nullptr_t, + const NonAtomicSharedPtr& rhs) noexcept { + return rhs.get() == nullptr; + } + + friend bool operator!=( + std::nullptr_t, + const NonAtomicSharedPtr& rhs) noexcept { + return rhs.get() != nullptr; + } +}; + +#ifdef USE_MULTI_THREAD +template +using SharedPtr = ::std::shared_ptr; + +// make_shared for std::shared_ptr +template +std::shared_ptr make_shared(Args&&... args) { + return std::make_shared(std::forward(args)...); +} + +#else +template +using SharedPtr = ::standalone::slim::NonAtomicSharedPtr; + +// make_shared for NonAtomicSharedPtr +template +NonAtomicSharedPtr make_shared(Args&&... args) { + return NonAtomicSharedPtr(new T(std::forward(args)...)); +} + +#endif // USE_MULTI_THREAD +} // namespace standalone::slim diff --git a/backends/aoti/slim/util/SizeUtil.h b/backends/aoti/slim/util/SizeUtil.h new file mode 100644 index 00000000000..d22416cd176 --- /dev/null +++ b/backends/aoti/slim/util/SizeUtil.h @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace standalone::slim { +#ifndef STANDALONE_MOBILE +inline constexpr uint64_t storage_max() { + // int64_t and size_t are used somewhat inconsistently throughout ATen. + // To be safe, storage size calculations must fit in both types. + constexpr auto int64_max = + static_cast(std::numeric_limits::max()); + constexpr auto size_max = + static_cast(std::numeric_limits::max()); + return std::min(int64_max, size_max); +} + +/** + * Compute the number of elements based on the sizes of a + * tensor. Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. + */ +inline int64_t safe_compute_numel(standalone::c10::IntArrayRef sizes) { + uint64_t n = 1; + bool overflowed = standalone::c10::safe_multiplies_u64(sizes, &n); + overflowed |= (n > storage_max()); + STANDALONE_CHECK(!overflowed, "numel: integer multiplication overflow"); + return static_cast(n); +} + +inline std::vector safe_compute_contiguous_strides( + c10::IntArrayRef sizes) { + int64_t ndim = static_cast(sizes.size()); + std::vector strides(ndim); + if (ndim > 0) { + uint64_t stride = 1; + bool overflowed = false; + for (int64_t i = ndim - 1; i >= 0; i--) { + strides[i] = static_cast(stride); + if (sizes[i] != 0) { + uint64_t new_stride = 0; + overflowed |= c10::mul_overflows( + stride, static_cast(sizes[i]), &new_stride); + stride = new_stride; + } + } + STANDALONE_CHECK( + !overflowed, "contiguous_strides: stride multiplication overflow"); + } + return strides; +} +#endif // STANDALONE_MOBILE + +inline int64_t compute_numel(standalone::c10::IntArrayRef sizes) { +#ifndef STANDALONE_MOBILE + // Use overflow checks if supported by the compiler + return safe_compute_numel(sizes); +#else + return standalone::c10::multiply_integers(sizes); +#endif +} + +// named computeStorageNbytesContiguous in c10 +inline size_t compute_storage_nbytes_contiguous( + standalone::c10::IntArrayRef sizes, + size_t itemsize_bytes, + size_t storage_offset) { +// Ignore overflow checks on mobile +#ifndef STANDALONE_MOBILE + uint64_t size = 1; + bool overflowed = standalone::c10::safe_multiplies_u64(sizes, &size); + overflowed |= standalone::c10::add_overflows(size, storage_offset, &size); + overflowed |= standalone::c10::mul_overflows(size, itemsize_bytes, &size); + overflowed |= size > storage_max(); + STANDALONE_CHECK( + !overflowed, "Storage size calculation overflowed with sizes=", sizes); + return static_cast(size); +#else + const auto numel = multiply_integers(sizes); + return itemsize_bytes * (storage_offset + numel); +#endif +} + +// named computeStorageNbytes in c10 +inline size_t compute_storage_nbytes( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + size_t itemsize_bytes, + size_t storage_offset) { + STANDALONE_CHECK( + sizes.size() == strides.size(), + "dimensionality of sizes (", + sizes.size(), + ") must match dimensionality of strides (", + strides.size(), + ")"); + +// Ignore overflow checks on mobile +#ifndef STANDALONE_MOBILE + // size of the underlying storage is 1 bigger than the offset + // of the last element according to stride + uint64_t size = storage_offset + 1; + bool overflowed = false; + for (const auto i : standalone::c10::irange(sizes.size())) { + if (sizes[i] == 0) { + return 0; + } + + uint64_t strided_size = 0; + overflowed |= + standalone::c10::mul_overflows(strides[i], sizes[i] - 1, &strided_size); + overflowed |= standalone::c10::add_overflows(size, strided_size, &size); + } + overflowed |= standalone::c10::mul_overflows(size, itemsize_bytes, &size); + overflowed |= size > storage_max(); + STANDALONE_CHECK( + !overflowed, + "Storage size calculation overflowed with sizes=", + sizes, + " and strides=", + strides); + return static_cast(size); +#else + // size of the underlying storage is 1 bigger than the offset + // of the last element according to stride + uint64_t size = 1; + for (const auto i : standalone::c10::irange(sizes.size())) { + if (sizes[i] == 0) { + return 0; + } + + size += strides[i] * (sizes[i] - 1); + } + return itemsize_bytes * (storage_offset + size); +#endif +} + +inline std::vector compute_contiguous_strides(c10::IntArrayRef sizes) { +#ifndef STANDALONE_MOBILE + return safe_compute_contiguous_strides(sizes); +#else + int64_t ndim = static_cast(sizes.size()); + std::vector strides(ndim); + if (ndim > 0) { + int64_t stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + if (sizes[i] != 0) { + stride *= sizes[i]; + } + } + } + return strides; +#endif +} + +// calculates the final concrete shape by also filling in at most one '-1' +// dimension. +inline std::vector infer_size( + standalone::c10::IntArrayRef shape, + int64_t numel) { + int64_t new_size = 1; + std::optional infer_dim; + std::vector result_shape; + result_shape.reserve(shape.size()); + + size_t ndim = shape.size(); + bool overflowed = false; + for (size_t dim = 0; dim < ndim; dim++) { + if (shape[dim] == -1) { + STANDALONE_CHECK( + !infer_dim.has_value(), "only one dimension can be inferred"); + infer_dim = dim; + result_shape.push_back(-1); // placeholder + } else { + STANDALONE_CHECK(shape[dim] >= 0, "invalid shape dimension ", shape[dim]); + overflowed |= + standalone::c10::mul_overflows(new_size, shape[dim], &new_size); + result_shape.push_back(shape[dim]); + } + } + STANDALONE_CHECK(!overflowed, "shape calculation overflowed"); + + if (infer_dim.has_value()) { + STANDALONE_CHECK( + new_size != 0, + "cannot reshape tensor of 0 elements into shape with -1"); + STANDALONE_CHECK( + numel % new_size == 0, "shape is invalid for input size ", numel); + result_shape[*infer_dim] = numel / new_size; + } else { + STANDALONE_CHECK( + numel == new_size, "shape is invalid for input of size ", numel); + } + return result_shape; +} + +// it determines if a reshape is possible as a view. +// If so, it returns the new strides +// If not, it returns an empty optional +inline std::optional> compute_stride( + standalone::c10::IntArrayRef old_sizes, + standalone::c10::IntArrayRef old_strides, + standalone::c10::IntArrayRef new_sizes) { + if (old_sizes.empty()) { + return std::vector(new_sizes.size(), 1); + } + + // NOTE: stride is arbitrary in the numel() == 0 case; + // to match NumPy behavior we copy the strides if the size matches, otherwise + // we use the stride as if it were computed via resize. + // This could perhaps be combined with the below code, but the complexity + // didn't seem worth it. + size_t numel = compute_numel(old_sizes); + if (numel == 0 && old_sizes == new_sizes) { + return old_strides.vec(); + } + + int64_t new_sizes_len = static_cast(new_sizes.size()); + std::vector new_strides(new_sizes_len); + if (numel == 0) { + for (int64_t view_d = new_sizes_len - 1; view_d >= 0; view_d--) { + if (view_d == new_sizes_len - 1) { + new_strides[view_d] = 1; + } else { + new_strides[view_d] = std::max(new_sizes[view_d + 1], 1) * + new_strides[view_d + 1]; + } + } + return new_strides; + } + + int64_t view_d = new_sizes_len - 1; + int64_t chunk_base_stride = old_strides.back(); + int64_t tensor_numel = 1; + int64_t view_numel = 1; + bool overflowed = false; + for (int64_t tensor_d = static_cast(old_sizes.size()) - 1; + tensor_d >= 0; + tensor_d--) { + // TODO: ask if this could lead to overflow by any chance? + // even if so, overflow is not handled in the aten implementation + overflowed |= standalone::c10::mul_overflows( + tensor_numel, old_sizes[tensor_d], &tensor_numel); + + bool is_chunk_end = (tensor_d == 0) || + (old_sizes[tensor_d - 1] != 1 && + old_strides[tensor_d - 1] != tensor_numel * chunk_base_stride); + + if (is_chunk_end) { + while (view_d >= 0 && + (view_numel < tensor_numel || new_sizes[view_d] == 1)) { + new_strides[view_d] = view_numel * chunk_base_stride; + view_numel *= new_sizes[view_d]; + view_d--; + } + if (view_numel != tensor_numel) { + return std::nullopt; // Not viewable + } + if (tensor_d > 0) { + chunk_base_stride = old_strides[tensor_d - 1]; + tensor_numel = 1; + view_numel = 1; + } + } + } + STANDALONE_CHECK(!overflowed, "overflowed while computing strides"); + + if (view_d != -1) { + return std::nullopt; // not viewable + } + return new_strides; +} + +} // namespace standalone::slim