Skip to content

Commit 250e4c9

Browse files
committed
[slimtensor migration 3/n] use et_check instead of standalone check
This stack aims to migrate slim tensor into ExecuTorch stack to make it as internal tensor representation of cudabackend. This diff replace origianl standalone_check with et_check to meet et standard Differential Revision: [D89445966](https://our.internmc.facebook.com/intern/diff/D89445966/) ghstack-source-id: 330180568 Pull Request resolved: #16314
1 parent 9200253 commit 250e4c9

22 files changed

+200
-339
lines changed

backends/aoti/slim/c10/core/Device.h

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

33
#include <executorch/backends/aoti/slim/c10/core/DeviceType.h>
4-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
54
#include <executorch/backends/aoti/slim/c10/util/StringUtil.h>
5+
#include <executorch/runtime/platform/assert.h>
66

77
#include <algorithm>
88
#include <array>
@@ -70,12 +70,7 @@ inline DeviceType parse_type(const std::string& device_string) {
7070
device_names.push_back(it.first);
7171
}
7272
}
73-
STANDALONE_CHECK(
74-
false,
75-
"Expected one of ",
76-
Join(", ", device_names),
77-
" device type at start of device string: ",
78-
device_string);
73+
ET_CHECK_MSG(false, "Expected a valid device type at start of device string");
7974
}
8075
} // namespace detail
8176

@@ -111,7 +106,7 @@ struct Device final {
111106
/// where `cpu` or `cuda` specifies the device type, and
112107
/// `:<device-index>` optionally specifies a device index.
113108
/* implicit */ Device(const std::string& device_string) : Device(Type::CPU) {
114-
STANDALONE_CHECK(!device_string.empty(), "Device string must not be empty");
109+
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");
115110

116111
std::string device_name, device_index_str;
117112
detail::DeviceStringParsingState pstate =
@@ -170,21 +165,14 @@ struct Device final {
170165
(pstate == detail::DeviceStringParsingState::kINDEX_START &&
171166
device_index_str.empty());
172167

173-
STANDALONE_CHECK(
174-
!has_error, "Invalid device string: '", device_string, "'");
168+
ET_CHECK_MSG(!has_error, "Invalid device string");
175169

176170
try {
177171
if (!device_index_str.empty()) {
178172
index_ = static_cast<DeviceIndex>(std::stoi(device_index_str));
179173
}
180174
} catch (const std::exception&) {
181-
STANDALONE_CHECK(
182-
false,
183-
"Could not parse device index '",
184-
device_index_str,
185-
"' in device string '",
186-
device_string,
187-
"'");
175+
ET_CHECK_MSG(false, "Could not parse device index in device string");
188176
}
189177
type_ = detail::parse_type(device_name);
190178
validate();
@@ -326,13 +314,13 @@ struct Device final {
326314
// performance in micro-benchmarks.
327315
// This is safe to do, because backends that use the DeviceIndex
328316
// have a later check when we actually try to switch to that device.
329-
STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(
317+
ET_DCHECK_MSG(
330318
index_ >= -1,
331-
"Device index must be -1 or non-negative, got ",
319+
"Device index must be -1 or non-negative, got %d",
332320
static_cast<int>(index_));
333-
STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(
321+
ET_DCHECK_MSG(
334322
!is_cpu() || index_ <= 0,
335-
"CPU device index must be -1 or zero, got ",
323+
"CPU device index must be -1 or zero, got %d",
336324
static_cast<int>(index_));
337325
}
338326
};

backends/aoti/slim/c10/core/DeviceType.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <stdexcept>
1414
#include <string>
1515

16-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
16+
#include <executorch/runtime/platform/assert.h>
1717

1818
namespace executorch::backends::aoti::slim::c10 {
1919
enum class DeviceType : int8_t {
@@ -94,7 +94,7 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
9494

9595
int idx = static_cast<int>(d);
9696
if (idx < 0 || idx >= COMPILE_TIME_MAX_DEVICE_TYPES) {
97-
STANDALONE_CHECK(false, "Unknown device: ", static_cast<int16_t>(d));
97+
ET_CHECK_MSG(false, "Unknown device");
9898
}
9999
if (d == DeviceType::PrivateUse1) {
100100
return get_privateuse1_backend(lower_case);

backends/aoti/slim/c10/core/Layout.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
3+
#include <executorch/runtime/platform/assert.h>
44

55
#include <cstdint>
66
#include <ostream>
@@ -46,7 +46,7 @@ inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) {
4646
case c10::kJagged:
4747
return stream << "Jagged";
4848
default:
49-
STANDALONE_CHECK(false, "Unknown layout");
49+
ET_CHECK_MSG(false, "Unknown layout");
5050
}
5151
}
5252

backends/aoti/slim/c10/core/MemoryFormat.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include <executorch/backends/aoti/slim/c10/util/ArrayRef.h>
4-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
4+
#include <executorch/runtime/platform/assert.h>
55

66
#include <cstdint>
77
#include <ostream>
@@ -57,7 +57,7 @@ inline std::ostream& operator<<(
5757
case MemoryFormat::ChannelsLast3d:
5858
return stream << "ChannelsLast3d";
5959
default:
60-
STANDALONE_CHECK(false, "Unknown memory format ", memory_format);
60+
ET_CHECK_MSG(false, "Unknown memory format");
6161
}
6262
}
6363

@@ -79,8 +79,7 @@ inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
7979
strides[1] = strides[2] * sizes[2];
8080
return strides;
8181
default:
82-
STANDALONE_INTERNAL_ASSERT(
83-
false, "ChannelsLast2d doesn't support size ", sizes.size());
82+
ET_DCHECK_MSG(false, "ChannelsLast2d doesn't support this size");
8483
}
8584
}
8685

@@ -106,8 +105,7 @@ std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
106105
strides[1] = strides[2] * sizes[2];
107106
return strides;
108107
default:
109-
STANDALONE_INTERNAL_ASSERT(
110-
false, "ChannelsLast3d doesn't support size ", sizes.size());
108+
ET_DCHECK_MSG(false, "ChannelsLast3d doesn't support this size");
111109
}
112110
}
113111

backends/aoti/slim/c10/core/Scalar.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
99
#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
10-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
1110
#include <executorch/backends/aoti/slim/c10/util/Half.h>
1211
#include <executorch/backends/aoti/slim/c10/util/TypeCast.h>
1312
#include <executorch/backends/aoti/slim/c10/util/complex.h>
1413
#include <executorch/backends/aoti/slim/c10/util/overflows.h>
14+
#include <executorch/runtime/platform/assert.h>
1515

1616
// Copy-pasted from c10/core/Scalar.h, but dropping SymScalar support
1717

@@ -102,7 +102,7 @@ class Scalar {
102102
} else if (Tag::HAS_u == tag) { \
103103
return checked_convert<type, uint64_t>(v.u, #type); \
104104
} \
105-
STANDALONE_CHECK(false) \
105+
ET_CHECK_MSG(false, "Unknown Scalar tag"); \
106106
}
107107

108108
// TODO: Support ComplexHalf accessor
@@ -158,18 +158,17 @@ class Scalar {
158158
}
159159

160160
Scalar operator-() const {
161-
STANDALONE_CHECK(
161+
ET_CHECK_MSG(
162162
!isBoolean(),
163-
"torch boolean negative, the `-` operator, is not supported.");
163+
"torch boolean negative, the `-` operator, is not supported");
164164
if (isFloatingPoint()) {
165165
return Scalar(-v.d);
166166
} else if (isComplex()) {
167167
return Scalar(-v.z);
168168
} else if (isIntegral(false)) {
169169
return Scalar(-v.i);
170170
}
171-
STANDALONE_INTERNAL_ASSERT(
172-
false, "unknown ivalue tag ", static_cast<int>(tag));
171+
ET_CHECK_MSG(false, "unknown ivalue tag");
173172
}
174173

175174
Scalar conj() const {
@@ -188,8 +187,7 @@ class Scalar {
188187
} else if (isIntegral(false)) {
189188
return std::log(v.i);
190189
}
191-
STANDALONE_INTERNAL_ASSERT(
192-
false, "unknown ivalue tag ", static_cast<int>(tag));
190+
ET_CHECK_MSG(false, "unknown ivalue tag");
193191
}
194192

195193
template <
@@ -219,7 +217,7 @@ class Scalar {
219217
// boolean scalar does not equal to a non boolean value
220218
return false;
221219
} else {
222-
STANDALONE_INTERNAL_ASSERT(false);
220+
ET_CHECK_MSG(false, "unexpected tag in equal");
223221
}
224222
}
225223

@@ -249,7 +247,7 @@ class Scalar {
249247
// boolean scalar does not equal to a non boolean value
250248
return false;
251249
} else {
252-
STANDALONE_INTERNAL_ASSERT(false);
250+
ET_CHECK_MSG(false, "unexpected tag in equal");
253251
}
254252
}
255253

backends/aoti/slim/c10/core/ScalarType.h

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include <executorch/backends/aoti/slim/c10/util/Array.h>
44
#include <executorch/backends/aoti/slim/c10/util/BFloat16.h>
5-
#include <executorch/backends/aoti/slim/c10/util/Exception.h>
65
#include <executorch/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h>
76
#include <executorch/backends/aoti/slim/c10/util/Float8_e4m3fn.h>
87
#include <executorch/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h>
@@ -17,6 +16,7 @@
1716
#include <executorch/backends/aoti/slim/c10/util/quint2x4.h>
1817
#include <executorch/backends/aoti/slim/c10/util/quint4x2.h>
1918
#include <executorch/backends/aoti/slim/c10/util/quint8.h>
19+
#include <executorch/runtime/platform/assert.h>
2020

2121
#include <array>
2222
#include <cstddef>
@@ -388,7 +388,7 @@ inline size_t elementSize(ScalarType t) {
388388
switch (t) {
389389
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
390390
default:
391-
STANDALONE_CHECK(false, "Unknown ScalarType");
391+
ET_CHECK_MSG(false, "Unknown ScalarType");
392392
}
393393
#undef CASE_ELEMENTSIZE_CASE
394394
}
@@ -492,13 +492,13 @@ inline bool isSignedType(ScalarType t) {
492492
case ScalarType::QInt32:
493493
case ScalarType::QUInt4x2:
494494
case ScalarType::QUInt2x4:
495-
STANDALONE_CHECK(false, "isSignedType not supported for quantized types");
495+
ET_CHECK_MSG(false, "isSignedType not supported for quantized types");
496496
case ScalarType::Bits1x8:
497497
case ScalarType::Bits2x4:
498498
case ScalarType::Bits4x2:
499499
case ScalarType::Bits8:
500500
case ScalarType::Bits16:
501-
STANDALONE_CHECK(false, "Bits types are undefined");
501+
ET_CHECK_MSG(false, "Bits types are undefined");
502502
CASE_ISSIGNED(UInt16);
503503
CASE_ISSIGNED(UInt32);
504504
CASE_ISSIGNED(UInt64);
@@ -543,7 +543,7 @@ inline bool isSignedType(ScalarType t) {
543543
// Do not add default here, but rather define behavior of every new entry
544544
// here. `-Wswitch-enum` would raise a warning in those cases.
545545
}
546-
STANDALONE_CHECK(false, "Unknown ScalarType ", t);
546+
ET_CHECK_MSG(false, "Unknown ScalarType");
547547
#undef CASE_ISSIGNED
548548
}
549549

@@ -583,7 +583,7 @@ inline ScalarType toComplexType(ScalarType t) {
583583
case ScalarType::ComplexDouble:
584584
return ScalarType::ComplexDouble;
585585
default:
586-
STANDALONE_CHECK(false, "Unknown Complex ScalarType for ", t);
586+
ET_CHECK_MSG(false, "Unknown Complex ScalarType");
587587
}
588588
}
589589

@@ -678,26 +678,16 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
678678

679679
// Handle identically equal types
680680
if (isQIntType(a) || isQIntType(b)) {
681-
STANDALONE_CHECK(
682-
false,
683-
"promoteTypes with quantized numbers is not handled yet; figure out "
684-
"what the correct rules should be, offending types: ",
685-
toString(a),
686-
" ",
687-
toString(b));
681+
ET_CHECK_MSG(
682+
false, "promoteTypes with quantized numbers is not handled yet");
688683
}
689684

690685
if (isBitsType(a) || isBitsType(b)) {
691686
return ScalarType::Undefined;
692687
}
693688

694689
if (isFloat8Type(a) || isFloat8Type(b)) {
695-
STANDALONE_CHECK(
696-
false,
697-
"Promotion for Float8 Types is not supported, attempted to promote ",
698-
toString(a),
699-
" and ",
700-
toString(b));
690+
ET_CHECK_MSG(false, "Promotion for Float8 Types is not supported");
701691
}
702692

703693
if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
@@ -717,18 +707,13 @@ inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
717707
if (isFloatingType(b)) {
718708
return b;
719709
}
720-
STANDALONE_CHECK(
721-
false,
722-
"Promotion for uint16, uint32, uint64 types is not supported, "
723-
"attempted to promote ",
724-
toString(a),
725-
" and ",
726-
toString(b));
710+
ET_CHECK_MSG(
711+
false, "Promotion for uint16, uint32, uint64 types is not supported");
727712
}
728713
auto ix_a = dtype2index[static_cast<int64_t>(a)];
729-
STANDALONE_INTERNAL_ASSERT(ix_a != -1);
714+
ET_DCHECK_MSG(ix_a != -1, "Invalid ScalarType a");
730715
auto ix_b = dtype2index[static_cast<int64_t>(b)];
731-
STANDALONE_INTERNAL_ASSERT(ix_b != -1);
716+
ET_DCHECK_MSG(ix_b != -1, "Invalid ScalarType b");
732717

733718
// This table axes must be consistent with index2dtype
734719
// clang-format off

0 commit comments

Comments
 (0)