Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions backends/aoti/slim/c10/core/Device.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <executorch/backends/aoti/slim/c10/core/DeviceType.h>
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/backends/aoti/slim/c10/util/StringUtil.h>
#include <executorch/runtime/platform/assert.h>

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -70,12 +70,7 @@ inline DeviceType parse_type(const std::string& device_string) {
device_names.push_back(it.first);
}
}
STANDALONE_CHECK(
false,
"Expected one of ",
Join(", ", device_names),
" device type at start of device string: ",
device_string);
ET_CHECK_MSG(false, "Expected a valid device type at start of device string");
}
} // namespace detail

Expand Down Expand Up @@ -111,7 +106,7 @@ struct Device final {
/// where `cpu` or `cuda` specifies the device type, and
/// `:<device-index>` optionally specifies a device index.
/* implicit */ Device(const std::string& device_string) : Device(Type::CPU) {
STANDALONE_CHECK(!device_string.empty(), "Device string must not be empty");
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");

std::string device_name, device_index_str;
detail::DeviceStringParsingState pstate =
Expand Down Expand Up @@ -170,21 +165,14 @@ struct Device final {
(pstate == detail::DeviceStringParsingState::kINDEX_START &&
device_index_str.empty());

STANDALONE_CHECK(
!has_error, "Invalid device string: '", device_string, "'");
ET_CHECK_MSG(!has_error, "Invalid device string");

try {
if (!device_index_str.empty()) {
index_ = static_cast<DeviceIndex>(std::stoi(device_index_str));
}
} catch (const std::exception&) {
STANDALONE_CHECK(
false,
"Could not parse device index '",
device_index_str,
"' in device string '",
device_string,
"'");
ET_CHECK_MSG(false, "Could not parse device index in device string");
}
type_ = detail::parse_type(device_name);
validate();
Expand Down Expand Up @@ -326,13 +314,13 @@ struct Device final {
// performance in micro-benchmarks.
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(
ET_DCHECK_MSG(
index_ >= -1,
"Device index must be -1 or non-negative, got ",
"Device index must be -1 or non-negative, got %d",
static_cast<int>(index_));
STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(
ET_DCHECK_MSG(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
"CPU device index must be -1 or zero, got %d",
static_cast<int>(index_));
}
};
Expand Down
4 changes: 2 additions & 2 deletions backends/aoti/slim/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <stdexcept>
#include <string>

#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {
enum class DeviceType : int8_t {
Expand Down Expand Up @@ -94,7 +94,7 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {

int idx = static_cast<int>(d);
if (idx < 0 || idx >= COMPILE_TIME_MAX_DEVICE_TYPES) {
STANDALONE_CHECK(false, "Unknown device: ", static_cast<int16_t>(d));
ET_CHECK_MSG(false, "Unknown device");
}
if (d == DeviceType::PrivateUse1) {
return get_privateuse1_backend(lower_case);
Expand Down
4 changes: 2 additions & 2 deletions backends/aoti/slim/c10/core/Layout.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/runtime/platform/assert.h>

#include <cstdint>
#include <ostream>
Expand Down Expand Up @@ -46,7 +46,7 @@ inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) {
case c10::kJagged:
return stream << "Jagged";
default:
STANDALONE_CHECK(false, "Unknown layout");
ET_CHECK_MSG(false, "Unknown layout");
}
}

Expand Down
10 changes: 4 additions & 6 deletions backends/aoti/slim/c10/core/MemoryFormat.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <executorch/backends/aoti/slim/c10/util/ArrayRef.h>
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/runtime/platform/assert.h>

#include <cstdint>
#include <ostream>
Expand Down Expand Up @@ -57,7 +57,7 @@ inline std::ostream& operator<<(
case MemoryFormat::ChannelsLast3d:
return stream << "ChannelsLast3d";
default:
STANDALONE_CHECK(false, "Unknown memory format ", memory_format);
ET_CHECK_MSG(false, "Unknown memory format");
}
}

Expand All @@ -79,8 +79,7 @@ inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
STANDALONE_INTERNAL_ASSERT(
false, "ChannelsLast2d doesn't support size ", sizes.size());
ET_DCHECK_MSG(false, "ChannelsLast2d doesn't support this size");
}
}

Expand All @@ -106,8 +105,7 @@ std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
strides[1] = strides[2] * sizes[2];
return strides;
default:
STANDALONE_INTERNAL_ASSERT(
false, "ChannelsLast3d doesn't support size ", sizes.size());
ET_DCHECK_MSG(false, "ChannelsLast3d doesn't support this size");
}
}

Expand Down
21 changes: 9 additions & 12 deletions backends/aoti/slim/c10/core/Scalar.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
#pragma once

#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/backends/aoti/slim/c10/util/Half.h>
#include <executorch/backends/aoti/slim/c10/util/TypeCast.h>
#include <executorch/backends/aoti/slim/c10/util/complex.h>
#include <executorch/backends/aoti/slim/c10/util/overflows.h>
#include <executorch/runtime/platform/assert.h>

// Copy-pasted from c10/core/Scalar.h, but dropping SymScalar support

Expand Down Expand Up @@ -102,7 +101,7 @@ class Scalar {
} else if (Tag::HAS_u == tag) { \
return checked_convert<type, uint64_t>(v.u, #type); \
} \
STANDALONE_CHECK(false) \
ET_CHECK_MSG(false, "Unknown Scalar tag"); \
}

// TODO: Support ComplexHalf accessor
Expand Down Expand Up @@ -158,18 +157,17 @@ class Scalar {
}

Scalar operator-() const {
STANDALONE_CHECK(
ET_CHECK_MSG(
!isBoolean(),
"torch boolean negative, the `-` operator, is not supported.");
"torch boolean negative, the `-` operator, is not supported");
if (isFloatingPoint()) {
return Scalar(-v.d);
} else if (isComplex()) {
return Scalar(-v.z);
} else if (isIntegral(false)) {
return Scalar(-v.i);
}
STANDALONE_INTERNAL_ASSERT(
false, "unknown ivalue tag ", static_cast<int>(tag));
ET_CHECK_MSG(false, "unknown ivalue tag");
}

Scalar conj() const {
Expand All @@ -188,8 +186,7 @@ class Scalar {
} else if (isIntegral(false)) {
return std::log(v.i);
}
STANDALONE_INTERNAL_ASSERT(
false, "unknown ivalue tag ", static_cast<int>(tag));
ET_CHECK_MSG(false, "unknown ivalue tag");
}

template <
Expand Down Expand Up @@ -219,7 +216,7 @@ class Scalar {
// boolean scalar does not equal to a non boolean value
return false;
} else {
STANDALONE_INTERNAL_ASSERT(false);
ET_CHECK_MSG(false, "unexpected tag in equal");
}
}

Expand Down Expand Up @@ -249,7 +246,7 @@ class Scalar {
// boolean scalar does not equal to a non boolean value
return false;
} else {
STANDALONE_INTERNAL_ASSERT(false);
ET_CHECK_MSG(false, "unexpected tag in equal");
}
}

Expand All @@ -276,7 +273,7 @@ class Scalar {
} else if (isBoolean()) {
return executorch::backends::aoti::slim::c10::ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
ET_CHECK_MSG(false, "Unknown scalar type.");
}
}

Expand Down
41 changes: 13 additions & 28 deletions backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <executorch/backends/aoti/slim/c10/util/Array.h>
#include <executorch/backends/aoti/slim/c10/util/BFloat16.h>
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
#include <executorch/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h>
#include <executorch/backends/aoti/slim/c10/util/Float8_e4m3fn.h>
#include <executorch/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h>
Expand All @@ -17,6 +16,7 @@
#include <executorch/backends/aoti/slim/c10/util/quint2x4.h>
#include <executorch/backends/aoti/slim/c10/util/quint4x2.h>
#include <executorch/backends/aoti/slim/c10/util/quint8.h>
#include <executorch/runtime/platform/assert.h>

#include <array>
#include <cstddef>
Expand Down Expand Up @@ -388,7 +388,7 @@ inline size_t elementSize(ScalarType t) {
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
default:
STANDALONE_CHECK(false, "Unknown ScalarType");
ET_CHECK_MSG(false, "Unknown ScalarType");
}
#undef CASE_ELEMENTSIZE_CASE
}
Expand Down Expand Up @@ -492,13 +492,13 @@ inline bool isSignedType(ScalarType t) {
case ScalarType::QInt32:
case ScalarType::QUInt4x2:
case ScalarType::QUInt2x4:
STANDALONE_CHECK(false, "isSignedType not supported for quantized types");
ET_CHECK_MSG(false, "isSignedType not supported for quantized types");
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Bits16:
STANDALONE_CHECK(false, "Bits types are undefined");
ET_CHECK_MSG(false, "Bits types are undefined");
CASE_ISSIGNED(UInt16);
CASE_ISSIGNED(UInt32);
CASE_ISSIGNED(UInt64);
Expand Down Expand Up @@ -543,7 +543,7 @@ inline bool isSignedType(ScalarType t) {
// Do not add default here, but rather define behavior of every new entry
// here. `-Wswitch-enum` would raise a warning in those cases.
}
STANDALONE_CHECK(false, "Unknown ScalarType ", t);
ET_CHECK_MSG(false, "Unknown ScalarType");
#undef CASE_ISSIGNED
}

Expand Down Expand Up @@ -583,7 +583,7 @@ inline ScalarType toComplexType(ScalarType t) {
case ScalarType::ComplexDouble:
return ScalarType::ComplexDouble;
default:
STANDALONE_CHECK(false, "Unknown Complex ScalarType for ", t);
ET_CHECK_MSG(false, "Unknown Complex ScalarType");
}
}

Expand Down Expand Up @@ -678,26 +678,16 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {

// Handle identically equal types
if (isQIntType(a) || isQIntType(b)) {
STANDALONE_CHECK(
false,
"promoteTypes with quantized numbers is not handled yet; figure out "
"what the correct rules should be, offending types: ",
toString(a),
" ",
toString(b));
ET_CHECK_MSG(
false, "promoteTypes with quantized numbers is not handled yet");
}

if (isBitsType(a) || isBitsType(b)) {
return ScalarType::Undefined;
}

if (isFloat8Type(a) || isFloat8Type(b)) {
STANDALONE_CHECK(
false,
"Promotion for Float8 Types is not supported, attempted to promote ",
toString(a),
" and ",
toString(b));
ET_CHECK_MSG(false, "Promotion for Float8 Types is not supported");
}

if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
Expand All @@ -717,18 +707,13 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
if (isFloatingType(b)) {
return b;
}
STANDALONE_CHECK(
false,
"Promotion for uint16, uint32, uint64 types is not supported, "
"attempted to promote ",
toString(a),
" and ",
toString(b));
ET_CHECK_MSG(
false, "Promotion for uint16, uint32, uint64 types is not supported");
}
auto ix_a = dtype2index[static_cast<int64_t>(a)];
STANDALONE_INTERNAL_ASSERT(ix_a != -1);
ET_DCHECK_MSG(ix_a != -1, "Invalid ScalarType a");
auto ix_b = dtype2index[static_cast<int64_t>(b)];
STANDALONE_INTERNAL_ASSERT(ix_b != -1);
ET_DCHECK_MSG(ix_b != -1, "Invalid ScalarType b");

// This table axes must be consistent with index2dtype
// clang-format off
Expand Down
Loading
Loading