From 636310d0cbd0339d40b9de2cd9bc430b3a61f285 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 1 Jun 2026 16:49:06 +0200 Subject: [PATCH] add shared loop creation helpers add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass --- src/PIM/CMakeLists.txt | 5 +- src/PIM/Common/CMakeLists.txt | 3 + src/PIM/Common/IR/LoopUtils.cpp | 96 +++ src/PIM/Common/IR/LoopUtils.hpp | 30 + src/PIM/Common/Support/CheckedArithmetic.cpp | 222 ++++++ src/PIM/Common/Support/CheckedArithmetic.hpp | 107 +++ src/PIM/Compiler/CMakeLists.txt | 5 +- src/PIM/Compiler/PimBinaryFormat.hpp | 13 +- src/PIM/Compiler/PimCodeGen.cpp | 72 +- src/PIM/Compiler/PimCompilerUtils.cpp | 2 +- .../Common/ComputeRegionBuilder.hpp | 8 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 92 ++- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 304 +++---- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 292 ++++--- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 159 ++-- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 78 +- .../ONNXToSpatial/Patterns/Post.cpp | 14 +- .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 152 ++-- .../BatchCoreLoweringPatterns.cpp | 65 +- src/PIM/Conversion/SpatialToPim/Common.cpp | 8 +- src/PIM/Conversion/SpatialToPim/Common.hpp | 4 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 32 +- .../SpatialToPim/Patterns/ChannelLowering.cpp | 17 +- .../SpatialToPim/ReturnPathNormalization.cpp | 106 ++- .../SpatialToPim/SpatialToPimPass.cpp | 81 +- .../SpatialToPim/SpatialToPimPass.hpp | 2 +- src/PIM/Dialect/Pim/CMakeLists.txt | 5 +- .../Bufferization/BufferizationUtils.cpp | 22 +- .../Bufferization/BufferizationUtils.hpp | 3 +- .../Pim/Transforms/Bufferization/Common.cpp | 9 +- .../Pim/Transforms/Bufferization/Common.hpp | 3 +- .../Bufferization/ContiguityPatterns.cpp | 62 +- .../OpBufferizationInterfaces.cpp | 53 +- .../Bufferization/PimBufferizationPass.cpp | 6 +- .../HostConstantFolding/CMakeLists.txt | 12 + .../HostConstantFolding/Common.cpp | 0 .../HostConstantFolding/Common.hpp | 0 .../HostConstantFoldingPass.cpp | 0 .../HostConstantFolding/Patterns.hpp | 0 .../HostConstantFolding/Patterns/Constant.cpp | 18 +- .../HostConstantFolding/Patterns/Subview.cpp | 0 .../CMakeLists.txt | 9 + .../MaterializeHostConstantsPass.cpp | 34 +- .../MemoryCoalescing/CMakeLists.txt | 14 + .../MemoryCoalescing.cpp} | 40 +- .../MemoryCoalescing.hpp} | 13 +- .../MemoryCoalescingPass.cpp} | 41 +- .../StaticMemoryCoalescing/CMakeLists.txt | 14 - .../Transforms/Verification/CMakeLists.txt | 11 + .../Verification}/VerificationPass.cpp | 0 .../MaterializeMergeSchedule.cpp | 753 ++++++++++-------- .../MergeComputeNodesPass.cpp | 9 +- src/PIM/Pass/CMakeLists.txt | 6 - src/PIM/Pass/PIMPasses.h | 2 +- src/PIM/PimAccelerator.cpp | 2 +- 55 files changed, 2007 insertions(+), 1103 deletions(-) create mode 100644 src/PIM/Common/IR/LoopUtils.cpp create mode 100644 src/PIM/Common/IR/LoopUtils.hpp create mode 100644 src/PIM/Common/Support/CheckedArithmetic.cpp create mode 100644 src/PIM/Common/Support/CheckedArithmetic.hpp create mode 100644 src/PIM/Dialect/Pim/Transforms/HostConstantFolding/CMakeLists.txt rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/Common.cpp (100%) rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/Common.hpp (100%) rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/HostConstantFoldingPass.cpp (100%) rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/Patterns.hpp (100%) rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/Patterns/Constant.cpp (97%) rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms}/HostConstantFolding/Patterns/Subview.cpp (100%) create mode 100644 src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms/HostConstantMaterialization}/MaterializeHostConstantsPass.cpp (79%) create mode 100644 src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/CMakeLists.txt rename src/PIM/Dialect/Pim/Transforms/{StaticMemoryCoalescing/StaticMemoryCoalescing.cpp => MemoryCoalescing/MemoryCoalescing.cpp} (77%) rename src/PIM/Dialect/Pim/Transforms/{StaticMemoryCoalescing/StaticMemoryCoalescing.hpp => MemoryCoalescing/MemoryCoalescing.hpp} (55%) rename src/PIM/Dialect/Pim/Transforms/{StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp => MemoryCoalescing/MemoryCoalescingPass.cpp} (82%) delete mode 100644 src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/CMakeLists.txt create mode 100644 src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt rename src/PIM/{Pass/PimCodegen => Dialect/Pim/Transforms/Verification}/VerificationPass.cpp (100%) diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index 15c25a9..0bd0305 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -121,6 +121,9 @@ add_pim_library(OMPIMAccel OMSpatialToPim OMPimCommon OMPimBufferization - OMPimStaticMemoryCoalescing + OMPimMemoryCoalescing + OMPimHostConstantFolding + OMPimHostConstantMaterialization + OMPimVerification MLIRTensorInferTypeOpInterfaceImpl ) diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 81a02a6..3ef3168 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -5,9 +5,11 @@ add_pim_library(OMPimCommon IR/ConstantUtils.cpp IR/CoreBlockUtils.cpp IR/EntryPointUtils.cpp + IR/LoopUtils.cpp IR/ShapeUtils.cpp IR/SubviewUtils.cpp IR/WeightUtils.cpp + Support/CheckedArithmetic.cpp Support/DebugDump.cpp Support/Diagnostics.cpp Support/FileSystemUtils.cpp @@ -20,6 +22,7 @@ add_pim_library(OMPimCommon LINK_LIBS PUBLIC MLIRLinalgDialect + MLIRSCFDialect onnx SpatialOps PimOps diff --git a/src/PIM/Common/IR/LoopUtils.cpp b/src/PIM/Common/IR/LoopUtils.cpp new file mode 100644 index 0000000..514a5ee --- /dev/null +++ b/src/PIM/Common/IR/LoopUtils.cpp @@ -0,0 +1,96 @@ +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "llvm/Support/MathExtras.h" + +#include + +#include "ConstantUtils.hpp" +#include "LoopUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static std::optional getStaticTripCount(Value lowerBound, Value upperBound, Value step) { + auto lower = matchConstantIndexValue(lowerBound); + auto upper = matchConstantIndexValue(upperBound); + auto stepValue = matchConstantIndexValue(step); + if (!lower || !upper || !stepValue) + return std::nullopt; + if (*stepValue <= 0) + return std::nullopt; + if (*upper <= *lower) + return int64_t {0}; + return llvm::divideCeil(*upper - *lower, *stepValue); +} + +} // namespace + +static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef yieldedValues) { + if (yieldedValues.size() == initArgs.size()) + return success(); + + emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size() + << " iter args"; + return failure(); +} + +FailureOr buildNormalizedScfFor(OpBuilder& builder, + Location loc, + Value lowerBound, + Value upperBound, + Value step, + ValueRange initArgs, + NormalizedLoopBodyBuilder bodyBuilder) { + NormalizedLoopResult result; + + if (auto stepValue = matchConstantIndexValue(step); stepValue && *stepValue <= 0) { + emitError(loc) << "normalized scf.for requires a positive step, got " << *stepValue; + return failure(); + } + + if (auto tripCount = getStaticTripCount(lowerBound, upperBound, step)) { + if (*tripCount == 0) { + llvm::append_range(result.results, initArgs); + return result; + } + + if (*tripCount == 1) { + result.inductionVar = lowerBound; + if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results))) + return failure(); + if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) + return failure(); + return result; + } + } + + result.loop = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs); + result.inductionVar = result.loop.getInductionVar(); + + { + OpBuilder::InsertionGuard guard(builder); + Block* body = result.loop.getBody(); + if (!body->empty()) + if (auto yieldOp = dyn_cast(body->back())) + yieldOp->erase(); + builder.setInsertionPointToEnd(body); + ValueRange iterArgs = result.loop.getRegionIterArgs(); + if (failed(bodyBuilder(builder, loc, result.inductionVar, iterArgs, result.results))) { + result.loop.erase(); + return failure(); + } + if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) { + result.loop.erase(); + return failure(); + } + scf::YieldOp::create(builder, loc, result.results); + } + builder.setInsertionPointAfter(result.loop); + result.results.assign(result.loop.getResults().begin(), result.loop.getResults().end()); + return result; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/LoopUtils.hpp b/src/PIM/Common/IR/LoopUtils.hpp new file mode 100644 index 0000000..48ea178 --- /dev/null +++ b/src/PIM/Common/IR/LoopUtils.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace onnx_mlir { + +struct NormalizedLoopResult { + mlir::Value inductionVar; + llvm::SmallVector results; + mlir::scf::ForOp loop; + + bool wasInlined() const { return !loop; } +}; + +using NormalizedLoopBodyBuilder = llvm::function_ref&)>; + +mlir::FailureOr buildNormalizedScfFor(mlir::OpBuilder& builder, + mlir::Location loc, + mlir::Value lowerBound, + mlir::Value upperBound, + mlir::Value step, + mlir::ValueRange initArgs, + NormalizedLoopBodyBuilder bodyBuilder); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/Support/CheckedArithmetic.cpp b/src/PIM/Common/Support/CheckedArithmetic.cpp new file mode 100644 index 0000000..77a8769 --- /dev/null +++ b/src/PIM/Common/Support/CheckedArithmetic.cpp @@ -0,0 +1,222 @@ +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include "CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir::pim { + +namespace { + +static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) { + llvm::errs() << "PIM " << fieldName << " " << message << "\n"; +} + +template +static FailureOr checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) { + static_assert(std::is_integral_v && std::is_integral_v, "checkedCastAtLocation requires integral types"); + + using ToLimits = std::numeric_limits; + + if constexpr (std::is_signed_v == std::is_signed_v) { + if (value < static_cast(ToLimits::min()) || value > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(loc, fieldName, "is outside representable range"); + return failure(); + } + } + else if constexpr (std::is_signed_v) { + using UnsignedFrom = std::make_unsigned_t; + using UnsignedTo = std::make_unsigned_t; + if (value < 0 || static_cast(value) > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(loc, fieldName, "is outside representable range"); + return failure(); + } + } + else { + using UnsignedFrom = std::make_unsigned_t; + using UnsignedTo = std::conditional_t, std::make_unsigned_t, To>; + if (static_cast(value) > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(loc, fieldName, "is outside representable range"); + return failure(); + } + } + + return static_cast(value); +} + +template +FailureOr checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) { + static_assert(std::is_integral_v && std::is_unsigned_v, + "checkedMulAtLocation requires unsigned integral types"); + if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { + emitCheckedArithmeticError(loc, fieldName, "multiplication overflow"); + return failure(); + } + return lhs * rhs; +} + +} // namespace + +InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) { + assert(anchor && "expected arithmetic diagnostics to have an anchor op"); + return anchor->emitOpError() << fieldName << " " << message; +} + +InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) { + return emitError(loc) << "PIM " << fieldName << " " << message; +} + +FailureOr checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) { + return checkedCast(value, anchor, fieldName); +} + +FailureOr checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) { + return checkedCast(value, anchor, fieldName); +} + +FailureOr checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) { + return checkedCast(value, anchor, fieldName); +} + +FailureOr checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) { + return checkedCast(value, anchor, fieldName); +} + +FailureOr +getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) { + assert(anchor && "checked op-based attrs require a non-null diagnostic anchor"); + auto checkedValue = checkedI32(value, anchor, fieldName); + if (failed(checkedValue)) + return failure(); + return builder.getI32IntegerAttr(*checkedValue); +} + +FailureOr +getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) { + assert(anchor && "checked op-based attrs require a non-null diagnostic anchor"); + auto checkedValue = checkedI32(value, anchor, fieldName); + if (failed(checkedValue)) + return failure(); + return builder.getI32IntegerAttr(*checkedValue); +} + +FailureOr getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) { + auto checkedValue = checkedCastAtLocation(value, loc, fieldName); + if (failed(checkedValue)) + return failure(); + return builder.getI32IntegerAttr(*checkedValue); +} + +FailureOr getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) { + auto checkedValue = checkedCastAtLocation(value, loc, fieldName); + if (failed(checkedValue)) + return failure(); + return builder.getI32IntegerAttr(*checkedValue); +} + +FailureOr getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) { + assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor"); + if (!type.hasStaticShape()) { + emitCheckedArithmeticError(anchor, fieldName, "requires static shaped type"); + return failure(); + } + if (!hasByteSizedElementType(type.getElementType())) { + emitCheckedArithmeticError(anchor, fieldName, "requires byte-sized element type"); + return failure(); + } + + uint64_t elements = 1; + for (int64_t dim : type.getShape()) { + if (dim < 0) { + emitCheckedArithmeticError(anchor, fieldName, "requires nonnegative dimensions"); + return failure(); + } + + auto nextElements = checkedMul(elements, static_cast(dim), anchor, fieldName); + if (failed(nextElements)) + return failure(); + elements = *nextElements; + } + + return checkedMul( + elements, static_cast(getElementTypeSizeInBytes(type.getElementType())), anchor, fieldName); +} + +FailureOr getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) { + if (!type.hasStaticShape()) { + emitCheckedArithmeticError(loc, fieldName, "requires static shaped type"); + return failure(); + } + if (!hasByteSizedElementType(type.getElementType())) { + emitCheckedArithmeticError(loc, fieldName, "requires byte-sized element type"); + return failure(); + } + + uint64_t elements = 1; + for (int64_t dim : type.getShape()) { + if (dim < 0) { + emitCheckedArithmeticError(loc, fieldName, "requires nonnegative dimensions"); + return failure(); + } + + auto nextElements = checkedMulAtLocation(elements, static_cast(dim), loc, fieldName); + if (failed(nextElements)) + return failure(); + elements = *nextElements; + } + + return checkedMulAtLocation( + elements, static_cast(getElementTypeSizeInBytes(type.getElementType())), loc, fieldName); +} + +int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) { + if (value < std::numeric_limits::min() || value > std::numeric_limits::max()) { + emitCrashMessage(fieldName, "is outside representable range"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return static_cast(value); +} + +int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) { + if (value > static_cast(std::numeric_limits::max())) { + emitCrashMessage(fieldName, "is outside representable range"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return static_cast(value); +} + +uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) { + if (value > static_cast(std::numeric_limits::max())) { + emitCrashMessage(fieldName, "is outside representable range"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return static_cast(value); +} + +size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) { + if (value < 0) { + emitCrashMessage(fieldName, "is outside representable range"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return static_cast(value); +} + +size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) { + if (rhs > std::numeric_limits::max() - lhs) { + emitCrashMessage(fieldName, "addition overflow"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return lhs + rhs; +} + +size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) { + if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { + emitCrashMessage(fieldName, "multiplication overflow"); + llvm_unreachable("PIM checked arithmetic failure"); + } + return lhs * rhs; +} + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Common/Support/CheckedArithmetic.hpp b/src/PIM/Common/Support/CheckedArithmetic.hpp new file mode 100644 index 0000000..c070d94 --- /dev/null +++ b/src/PIM/Common/Support/CheckedArithmetic.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/StringRef.h" + +#include +#include +#include +#include + +namespace onnx_mlir::pim { + +mlir::InFlightDiagnostic +emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message); + +mlir::InFlightDiagnostic +emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message); + +template +mlir::FailureOr checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) { + static_assert(std::is_integral_v && std::is_integral_v, "checkedCast requires integral types"); + + using ToLimits = std::numeric_limits; + + if constexpr (std::is_signed_v == std::is_signed_v) { + if (value < static_cast(ToLimits::min()) || value > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); + return mlir::failure(); + } + } + else if constexpr (std::is_signed_v) { + using UnsignedFrom = std::make_unsigned_t; + using UnsignedTo = std::make_unsigned_t; + if (value < 0 || static_cast(value) > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); + return mlir::failure(); + } + } + else { + using UnsignedFrom = std::make_unsigned_t; + using UnsignedTo = std::conditional_t, std::make_unsigned_t, To>; + if (static_cast(value) > static_cast(ToLimits::max())) { + emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); + return mlir::failure(); + } + } + + return static_cast(value); +} + +template +mlir::FailureOr checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) { + static_assert(std::is_integral_v && std::is_unsigned_v, "checkedAdd requires unsigned integral types"); + if (rhs > std::numeric_limits::max() - lhs) { + emitCheckedArithmeticError(anchor, fieldName, "addition overflow"); + return mlir::failure(); + } + return lhs + rhs; +} + +template +mlir::FailureOr checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) { + static_assert(std::is_integral_v && std::is_unsigned_v, "checkedMul requires unsigned integral types"); + if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { + emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow"); + return mlir::failure(); + } + return lhs * rhs; +} + +mlir::FailureOr checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); +mlir::FailureOr checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); + +mlir::FailureOr checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); + +mlir::FailureOr checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName); + +mlir::FailureOr +getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName); + +int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName); +int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName); +uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName); +size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName); +size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName); +size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName); + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt index c40bac1..85a3453 100644 --- a/src/PIM/Compiler/CMakeLists.txt +++ b/src/PIM/Compiler/CMakeLists.txt @@ -28,7 +28,10 @@ add_pim_library(OMPimCompilerUtils OMPimCompilerOptions OMPimCommon OMPimBufferization - OMPimStaticMemoryCoalescing + OMPimMemoryCoalescing + OMPimHostConstantFolding + OMPimHostConstantMaterialization + OMPimVerification OMPimPasses OMONNXToSpatial OMSpatialToPim diff --git a/src/PIM/Compiler/PimBinaryFormat.hpp b/src/PIM/Compiler/PimBinaryFormat.hpp index 31b06a5..85d0fce 100644 --- a/src/PIM/Compiler/PimBinaryFormat.hpp +++ b/src/PIM/Compiler/PimBinaryFormat.hpp @@ -6,8 +6,8 @@ #include "llvm/Support/raw_ostream.h" #include -#include -#include + +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" namespace onnx_mlir::pim_binary { @@ -95,15 +95,10 @@ inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecor writeInt32LE(os, record.generic3); } -inline int32_t toI32(int64_t value) { - assert(value >= std::numeric_limits::min() && value <= std::numeric_limits::max() - && "PIM binary field out of int32 range"); - return static_cast(value); -} +inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); } inline uint8_t toU8(int64_t value) { - assert(value >= 0 && value <= std::numeric_limits::max() && "PIM binary field out of uint8 range"); - return static_cast(value); + return onnx_mlir::pim::checkedU8OrCrash(static_cast(value), "binary field"); } inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) { diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 89d5c01..3c29ad9 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -25,12 +25,14 @@ #include #include #include +#include #include #include #include #include "Common/IR/CompactAsmUtils.hpp" #include "Common/PimCommon.hpp" +#include "Common/Support/CheckedArithmetic.hpp" #include "Common/Support/ReportUtils.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" @@ -71,12 +73,23 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional lane) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); - size_t allocSize = getShapedTypeSizeInBytes(type); + auto checkedAllocSize = + pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size"); + if (failed(checkedAllocSize)) + llvm_unreachable("Failed to compute checked allocation byte size"); + size_t allocSize = static_cast(*checkedAllocSize); MemEntry memEntry = {0, allocSize}; return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first; } @@ -272,7 +285,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, llvm_unreachable("Missing mem entry"); } - return iter->second.address + resolvedAddress->byteOffset; + size_t byteOffset = pim::checkedSizeOrCrash(resolvedAddress->byteOffset, "resolved PIM byte offset"); + return pim::checkedAddOrCrash(iter->second.address, byteOffset, "resolved PIM address"); } llvm::FailureOr PimAcceleratorMemory::getIndexValue(mlir::Value value, @@ -291,8 +305,12 @@ llvm::FailureOr PimAcceleratorMemory::getIndexValue(mlir::Value value, void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); } void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) { - reportEntries.push_back( - {MemoryReportEntry::Kind::Core, coreId, {static_cast(coreId)}, row, row.numAlloca, row.sizeAlloca}); + reportEntries.push_back({MemoryReportEntry::Kind::Core, + coreId, + {pim::checkedI32OrCrash(coreId, "memory report core id")}, + row, + row.numAlloca, + row.sizeAlloca}); } void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, @@ -402,24 +420,24 @@ void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t i pim_binary::InstructionRecord instruction; instruction.opcode = pim_binary::Opcode::sldi; instruction.rd = static_cast(registerNumber); - instruction.r2OrImm = static_cast(immediate); + instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate"); emitInstruction(instruction); } void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const { - genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); + genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address")); } void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const { - genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); - genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); + genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address")); + genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address")); } void PimCodeGen::setupRdRs1Rs2( size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const { - genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); - genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); - genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); + genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address")); + genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address")); + genSetRegisterImmediateUnsigned(2, pim::checkedAddOrCrash(rs2Address, rs2Offset, "rs2 address")); } void PimCodeGen::emitMemCopyOp(StringRef opName, @@ -437,8 +455,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName, instruction.r1 = 1; instruction.generic1 = 0; instruction.generic2 = 0; - instruction.generic3 = static_cast(size); - (void) sizeFieldName; + instruction.generic3 = pim::checkedI32OrCrash(size, sizeFieldName); emitInstruction(instruction); } @@ -448,10 +465,10 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t pim_binary::InstructionRecord instruction; instruction.opcode = pim_binary::opcodeFromString(opName); instruction.rd = 0; - instruction.r2OrImm = static_cast(remapCoreId(coreId)); + instruction.r2OrImm = pim::checkedI32OrCrash(remapCoreId(coreId), "communication core id"); instruction.generic1 = 0; instruction.generic2 = 0; - instruction.generic3 = static_cast(size); + instruction.generic3 = pim::checkedI32OrCrash(size, "communication byte size"); emitInstruction(instruction); } @@ -464,7 +481,7 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_ instruction.r1 = 1; instruction.r2OrImm = 8; instruction.generic1 = 0; - instruction.generic2 = static_cast(groupId); + instruction.generic2 = pim::checkedI32OrCrash(groupId, "mvm group id"); emitInstruction(instruction); } @@ -578,7 +595,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvaddOp.getLhs().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vvaddOp.getLhs().getType())); emitInstruction(instruction); } @@ -593,7 +610,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvsubOp.getLhs().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vvsubOp.getLhs().getType())); emitInstruction(instruction); } @@ -608,7 +625,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvmulOp.getLhs().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vvmulOp.getLhs().getType())); emitInstruction(instruction); } @@ -623,7 +640,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvmaxOp.getLhs().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vvmaxOp.getLhs().getType())); emitInstruction(instruction); } @@ -638,7 +655,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvdmulOp.getLhs().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vvdmulOp.getLhs().getType())); emitInstruction(instruction); } @@ -653,7 +670,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge instruction.r1 = 1; instruction.r2OrImm = 1; instruction.generic1 = 1; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vavgOp.getInput().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vavgOp.getInput().getType())); emitInstruction(instruction); } @@ -666,7 +683,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vrelu; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vreluOp.getInput().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vreluOp.getInput().getType())); emitInstruction(instruction); } @@ -679,7 +696,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vtanh; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vtanhOp.getInput().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vtanhOp.getInput().getType())); emitInstruction(instruction); } @@ -692,7 +709,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vsigm; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vsigmOp.getInput().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vsigmOp.getInput().getType())); emitInstruction(instruction); } @@ -705,8 +722,7 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa instruction.opcode = pim_binary::Opcode::vsoftmax; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = - static_cast(getShapedTypeSizeInBytes(cast(vsoftmaxOp.getInput().getType()))); + instruction.generic3 = getVectorByteSizeOrCrash(cast(vsoftmaxOp.getInput().getType())); emitInstruction(instruction); } @@ -1370,7 +1386,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup)) return err; xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); - reportedCoreIds.push_back(static_cast(job.emittedCoreId)); + reportedCoreIds.push_back(pim::checkedI32OrCrash(job.emittedCoreId, "batch report core id")); if (!batchPerCoreRow) batchPerCoreRow = result.reportRow; batchRow = addMemoryReportRows(batchRow, result.reportRow); diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index d2e2ea0..2d111cc 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -40,7 +40,6 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitPimBufferized) { pm.addPass(createPimBufferizationPass()); - pm.addPass(createPimStaticMemoryCoalescingPass()); pm.addPass(createMessagePass("Pim bufferized")); } @@ -48,6 +47,7 @@ void addPassesPim(OwningOpRef& module, pm.addPass(createPimHostConstantFoldingPass()); pm.addPass(createMessagePass("Pim host constants folded")); pm.addPass(createPimMaterializeHostConstantsPass()); + pm.addPass(createPimMemoryCoalescingPass()); pm.addPass(createPimVerificationPass()); pm.addPass(createMessagePass("Pim verified")); pm.addPass(createEmitPimCodePass()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp index bb5ba07..27d6471 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp @@ -12,6 +12,7 @@ #include #include +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -180,8 +181,11 @@ auto createSpatComputeBatch(RewriterT& rewriter, if (laneCount <= 0 || laneCount > std::numeric_limits::max()) return mlir::FailureOr(mlir::failure()); - auto batchOp = spatial::SpatComputeBatch::create( - rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast(laneCount)), weights, inputs); + auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count"); + if (mlir::failed(laneCountAttr)) + return mlir::FailureOr(mlir::failure()); + + auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs); mlir::SmallVector blockArgTypes {rewriter.getIndexType()}; mlir::SmallVector blockArgLocs {loc}; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index bda761c..0954a5c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -8,6 +8,7 @@ #include +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" @@ -305,58 +306,67 @@ static Value createIm2colRowComputes(Value x, auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); - auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); - rewriter.setInsertionPointToStart(im2colLoop.getBody()); + auto im2colLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cNumPatches, + c1, + ValueRange {im2colInit}, + [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value im2colAcc = iterArgs.front(); + Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch); + Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch); + Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); + Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); + Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); + Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); - Value patchIndex = im2colLoop.getInductionVar(); - Value im2colAcc = im2colLoop.getRegionIterArgs().front(); + SmallVector offsets = { + batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = + tensor::ExtractSliceOp::create(rewriter, nestedLoc, patchType, paddedInput, offsets, sizes, strides); - Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); - Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); - Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); - Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); - Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); - Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); + Value row = tensor::CollapseShapeOp::create(rewriter, + nestedLoc, + im2colRowType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); - SmallVector offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); - - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - im2colRowType, - patch, - SmallVector { - {0}, - {1, 2, 3} - }); - - SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; - SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; - SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value updatedIm2col = - tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); - scf::YieldOp::create(rewriter, loc, updatedIm2col); - - rewriter.setInsertionPointAfter(im2colLoop); - Value im2col = im2colLoop.getResult(0); + SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; + SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value updatedIm2col = + tensor::InsertSliceOp::create(rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); + yielded.push_back(updatedIm2col); + return success(); + }); + if (failed(im2colLoop)) + return failure(); + Value im2col = im2colLoop->results.front(); Value gemmInputRows = im2col; if (packFactor != 1) gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); + return success(); }); - return im2colComputeOp.getResult(0); + assert(succeeded(im2colComputeOp) && "Conv im2col compute construction must succeed"); + return im2colComputeOp->getResult(0); } static Value createCollectedConvOutput(ValueRange gemmRows, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 8fc6080..e073507 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -15,6 +15,7 @@ #include "Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" @@ -247,16 +248,16 @@ static Value createPaddedInputCompute(Value input, return computeOp.getResult(0); } -static spatial::SpatComputeBatch createVmmBatch(Value a, - Value b, - RankedTensorType aType, - RankedTensorType paddedBType, - RankedTensorType partialPiecesType, - int64_t numOutRows, - int64_t numKSlices, - int64_t numOutHSlices, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr createVmmBatch(Value a, + Value b, + RankedTensorType aType, + RankedTensorType paddedBType, + RankedTensorType partialPiecesType, + int64_t numOutRows, + int64_t numKSlices, + int64_t numOutHSlices, + ConversionPatternRewriter& rewriter, + Location loc) { const int64_t laneCount = partialPiecesType.getDimSize(0); auto batchOp = createSpatComputeBatch( rewriter, @@ -294,7 +295,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, createParallelInsertSliceIntoBatchOutput( rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides); }); - assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed"); + if (failed(batchOp)) + return failure(); return *batchOp; } @@ -416,15 +418,15 @@ static Value createBroadcastedBiasScalar(Value bias, return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult(); } -static spatial::SpatComputeBatch createVvdmulBatch(Value a, - Value b, - RankedTensorType aType, - RankedTensorType bType, - RankedTensorType scalarPiecesType, - RankedTensorType outType, - bool bAlreadyTransposed, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr createVvdmulBatch(Value a, + Value b, + RankedTensorType aType, + RankedTensorType bType, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + bool bAlreadyTransposed, + ConversionPatternRewriter& rewriter, + Location loc) { const int64_t numOutRows = outType.getDimSize(0); const int64_t numOutCols = outType.getDimSize(1); const int64_t reductionSize = aType.getDimSize(1); @@ -454,26 +456,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a, createParallelInsertSliceIntoBatchOutput( rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides); }); - assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed"); + if (failed(batchOp)) + return failure(); return *batchOp; } -static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, - Value bias, - RankedTensorType scalarPiecesType, - RankedTensorType biasType, - RankedTensorType outType, - float alpha, - float beta, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr createDynamicGemmOutputCompute(Value scalarPieces, + Value bias, + RankedTensorType scalarPiecesType, + RankedTensorType biasType, + RankedTensorType outType, + float alpha, + float beta, + ConversionPatternRewriter& rewriter, + Location loc) { const int64_t laneCount = scalarPiecesType.getDimSize(0); const int64_t numOutCols = outType.getDimSize(1); SmallVector inputs {scalarPieces}; if (bias) inputs.push_back(bias); - return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { + return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult { Value pieces = blockArgs[0]; Value biasArg = bias ? blockArgs[1] : Value(); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); @@ -481,40 +484,50 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); - auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); - rewriter.setInsertionPointToStart(loop.getBody()); + auto loop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cLaneCount, + c1, + ValueRange {outputInit}, + [&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value outputAcc = iterArgs.front(); + Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, nestedLoc); + Value column = + onnx_mlir::affineModConst(rewriter, nestedLoc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp()); + SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value scalar = tensor::ExtractSliceOp::create( + rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides) + .getResult(); + if (alpha != 1.0f) { + Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, nestedLoc); + scalar = spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, scalar, alphaTensor).getResult(); + } + if (biasArg) { + Value biasScalar = + createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, nestedLoc); + if (beta != 1.0f) { + Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, nestedLoc); + biasScalar = + spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, biasScalar, betaTensor).getResult(); + } + scalar = spatial::SpatVAddOp::create(rewriter, nestedLoc, scalarType, scalar, biasScalar).getResult(); + } + SmallVector outputOffsets {row, column}; + Value outputNext = + tensor::InsertSliceOp::create(rewriter, nestedLoc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides) + .getResult(); + yielded.push_back(outputNext); + return success(); + }); + if (failed(loop)) + return failure(); - Value lane = loop.getInductionVar(); - Value outputAcc = loop.getRegionIterArgs().front(); - Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); - Value column = - onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp()); - SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; - SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value scalar = - tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides) - .getResult(); - if (alpha != 1.0f) { - Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc); - scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult(); - } - if (biasArg) { - Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc); - if (beta != 1.0f) { - Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc); - biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult(); - } - scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult(); - } - SmallVector outputOffsets {row, column}; - Value outputNext = - tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides) - .getResult(); - scf::YieldOp::create(rewriter, loc, outputNext); - - rewriter.setInsertionPointAfter(loop); - spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0)); + spatial::SpatYieldOp::create(rewriter, loc, loop->results.front()); + return success(); }); } @@ -579,85 +592,92 @@ static Value reducePartialPiecesForHSlice(Value partialPiecesArg, return activePieces.front(); } -static spatial::SpatCompute createReductionCompute(Value partialPieces, - Value bias, - RankedTensorType partialPiecesType, - RankedTensorType outType, - RankedTensorType paddedOutType, - int64_t numKSlices, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr createReductionCompute(Value partialPieces, + Value bias, + RankedTensorType partialPiecesType, + RankedTensorType outType, + RankedTensorType paddedOutType, + int64_t numKSlices, + ConversionPatternRewriter& rewriter, + Location loc) { SmallVector inputs {partialPieces}; if (bias) inputs.push_back(bias); - auto computeOp = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { - Value partialPiecesArg = blockArgs[0]; - Value biasArg = bias ? blockArgs[1] : Value(); - if (biasArg && cast(biasArg.getType()) != paddedOutType) - biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc); + auto computeOp = + createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult { + Value partialPiecesArg = blockArgs[0]; + Value biasArg = bias ? blockArgs[1] : Value(); + if (biasArg && cast(biasArg.getType()) != paddedOutType) + biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc); - const int64_t numOutRows = outType.getDimSize(0); - const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue()); - auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, - partialPiecesType.getElementType()); + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue()); + auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, + partialPiecesType.getElementType()); - Value outputInit = - tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult(); - SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector pieceSizes {rewriter.getIndexAttr(numOutRows), - rewriter.getIndexAttr(crossbarSize.getValue())}; + Value outputInit = + tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult(); + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector pieceSizes {rewriter.getIndexAttr(numOutRows), + rewriter.getIndexAttr(crossbarSize.getValue())}; - auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { - Value reduced = - reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); - Value hOffset = onnx_mlir::affineMulConst( - rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); - if (biasArg) { - SmallVector biasOffsets {rewriter.getIndexAttr(0), hOffset}; - Value biasSlice = - tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides) - .getResult(); - reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult(); + auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { + Value reduced = + reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); + Value hOffset = onnx_mlir::affineMulConst( + rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); + if (biasArg) { + SmallVector biasOffsets {rewriter.getIndexAttr(0), hOffset}; + Value biasSlice = + tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides) + .getResult(); + reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult(); + } + + SmallVector outputOffsets {rewriter.getIndexAttr(0), hOffset}; + return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides) + .getResult(); + }; + + Value paddedOutput = outputInit; + if (numOutHSlices == 1) { + Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + paddedOutput = buildOutputSlice(outputInit, hSlice); + } + else { + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cOutHSlices = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); + auto hLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cOutHSlices, + c1, + ValueRange {outputInit}, + [&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl& yielded) { + yielded.push_back(buildOutputSlice(iterArgs.front(), hSlice)); + return success(); + }); + if (failed(hLoop)) + return failure(); + paddedOutput = hLoop->results.front(); } - SmallVector outputOffsets {rewriter.getIndexAttr(0), hOffset}; - return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides) - .getResult(); - }; - - Value paddedOutput = outputInit; - if (numOutHSlices == 1) { - Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - paddedOutput = buildOutputSlice(outputInit, hSlice); - } - else { - Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); - Value cOutHSlices = - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); - auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); - rewriter.setInsertionPointToStart(hLoop.getBody()); - - Value hSlice = hLoop.getInductionVar(); - Value outputAcc = hLoop.getRegionIterArgs().front(); - scf::YieldOp::create(rewriter, loc, buildOutputSlice(outputAcc, hSlice)); - - rewriter.setInsertionPointAfter(hLoop); - paddedOutput = hLoop.getResult(0); - } - - Value result = paddedOutput; - if (paddedOutType != outType) { - SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; - SmallVector outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)), - rewriter.getIndexAttr(outType.getDimSize(1))}; - result = - tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides) - .getResult(); - } - spatial::SpatYieldOp::create(rewriter, loc, result); - }); + Value result = paddedOutput; + if (paddedOutType != outType) { + SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)), + rewriter.getIndexAttr(outType.getDimSize(1))}; + result = + tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides) + .getResult(); + } + spatial::SpatYieldOp::create(rewriter, loc, result); + return success(); + }); return computeOp; } @@ -755,9 +775,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType()); auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc); + if (failed(batchOp)) + return failure(); auto outputCompute = createDynamicGemmOutputCompute( - batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc); - rewriter.replaceOp(gemmOp, outputCompute.getResults()); + batchOp->getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc); + if (failed(outputCompute)) + return failure(); + rewriter.replaceOp(gemmOp, outputCompute->getResults()); return success(); } @@ -832,10 +856,14 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, RankedTensorType::get({laneCount64, static_cast(crossbarSize.getValue())}, outType.getElementType()); auto batchOp = createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc); + if (failed(batchOp)) + return failure(); auto reductionCompute = createReductionCompute( - batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc); + batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc); + if (failed(reductionCompute)) + return failure(); - rewriter.replaceOp(gemmOp, reductionCompute.getResults()); + rewriter.replaceOp(gemmOp, reductionCompute->getResults()); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 57c99b4..c1dba86 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -8,6 +8,7 @@ #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" @@ -281,18 +282,18 @@ static Value getBatchLaneIndex( rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp()); } -static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, - Value b, - RankedTensorType aType, - int64_t aBatchCount, - RankedTensorType bType, - int64_t bBatchCount, - RankedTensorType partialPiecesType, - int64_t numOutRows, - int64_t numKSlices, - int64_t numOutHSlices, - PatternRewriter& rewriter, - Location loc) { +static FailureOr createBatchedVmmBatch(Value a, + Value b, + RankedTensorType aType, + int64_t aBatchCount, + RankedTensorType bType, + int64_t bBatchCount, + RankedTensorType partialPiecesType, + int64_t numOutRows, + int64_t numKSlices, + int64_t numOutHSlices, + PatternRewriter& rewriter, + Location loc) { const int64_t laneCount = partialPiecesType.getDimSize(0); auto batchOp = createSpatComputeBatch( rewriter, @@ -331,7 +332,8 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, createParallelInsertSliceIntoBatchOutput( rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2)); }); - assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed"); + if (failed(batchOp)) + return failure(); return *batchOp; } @@ -422,17 +424,17 @@ static Value extractDynamicBatchedRowVector(Value matrix, }); } -static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a, - int64_t aBatchCount, - Value b, - int64_t bBatchCount, - RankedTensorType aType, - RankedTensorType bType, - RankedTensorType scalarPiecesType, - RankedTensorType outType, - bool bAlreadyTransposed, - PatternRewriter& rewriter, - Location loc) { +static FailureOr createBatchedVvdmulBatch(Value a, + int64_t aBatchCount, + Value b, + int64_t bBatchCount, + RankedTensorType aType, + RankedTensorType bType, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + bool bAlreadyTransposed, + PatternRewriter& rewriter, + Location loc) { const int64_t numBatches = outType.getDimSize(0); const int64_t numOutRows = outType.getDimSize(1); const int64_t numOutCols = outType.getDimSize(2); @@ -466,64 +468,73 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a, createParallelInsertSliceIntoBatchOutput( rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2)); }); - assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed"); + if (failed(batchOp)) + return failure(); return *batchOp; } -static Value createBatchedDynamicOutputCompute(Value scalarPieces, - RankedTensorType scalarPiecesType, - RankedTensorType outType, - PatternRewriter& rewriter, - Location loc) { +static FailureOr createBatchedDynamicOutputCompute(Value scalarPieces, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + PatternRewriter& rewriter, + Location loc) { const int64_t laneCount = scalarPiecesType.getDimSize(0); const int64_t numOutRows = outType.getDimSize(1); const int64_t numOutCols = outType.getDimSize(2); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType()); - auto computeOp = - createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) { + auto computeOp = createSpatCompute<1>( + rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) -> LogicalResult { Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); - auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); - rewriter.setInsertionPointToStart(loop.getBody()); - - Value lane = loop.getInductionVar(); - Value outputAcc = loop.getRegionIterArgs().front(); - Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); - Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); - Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp); - Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp); - SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; - SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value scalar = tensor::ExtractSliceOp::create( - rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2)); - Value expanded = tensor::ExpandShapeOp::create(rewriter, - loc, - outputScalarType, - scalar, - SmallVector { - {0}, - {1, 2} - }); - SmallVector outputOffsets {batch, row, column}; - SmallVector outputSizes { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - scf::YieldOp::create( + auto loop = buildNormalizedScfFor( rewriter, loc, - tensor::InsertSliceOp::create( - rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) - .getResult()); - - rewriter.setInsertionPointAfter(loop); - spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0)); + c0, + cLaneCount, + c1, + ValueRange {outputInit}, + [&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value outputAcc = iterArgs.front(); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value batch = affineFloorDivConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp); + Value batchLane = affineModConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp); + Value row = affineFloorDivConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp); + Value column = affineModConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp); + SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value scalar = tensor::ExtractSliceOp::create( + rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2)); + Value expanded = tensor::ExpandShapeOp::create(rewriter, + nestedLoc, + outputScalarType, + scalar, + SmallVector { + {0}, + {1, 2} + }); + SmallVector outputOffsets {batch, row, column}; + SmallVector outputSizes = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value next = + tensor::InsertSliceOp::create( + rewriter, nestedLoc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) + .getResult(); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, loop->results.front()); + return success(); }); - return computeOp.getResult(0); + if (failed(computeOp)) + return failure(); + return computeOp->getResult(0); } static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) { @@ -587,16 +598,16 @@ static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg, return activePieces.front(); } -static Value createBatchedReductionCompute(Value partialPieces, - RankedTensorType partialPiecesType, - RankedTensorType outType, - RankedTensorType paddedOutType, - int64_t numBatches, - int64_t numKSlices, - PatternRewriter& rewriter, - Location loc) { +static FailureOr createBatchedReductionCompute(Value partialPieces, + RankedTensorType partialPiecesType, + RankedTensorType outType, + RankedTensorType paddedOutType, + int64_t numBatches, + int64_t numKSlices, + PatternRewriter& rewriter, + Location loc) { auto computeOp = createSpatCompute<1>( - rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) { + rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) -> LogicalResult { const int64_t numOutRows = outType.getDimSize(1); const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue()); auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, @@ -612,43 +623,55 @@ static Value createBatchedReductionCompute(Value partialPieces, Value cNumOutHSlices = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); - auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit}); - rewriter.setInsertionPointToStart(batchLoop.getBody()); - Value batch = batchLoop.getInductionVar(); - Value batchAcc = batchLoop.getRegionIterArgs().front(); - - auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc}); - rewriter.setInsertionPointToStart(hLoop.getBody()); - Value hSlice = hLoop.getInductionVar(); - Value outputAcc = hLoop.getRegionIterArgs().front(); - - Value reduced = reduceBatchedPartialPiecesForHSlice( - partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc); - Value expandedReduced = tensor::ExpandShapeOp::create(rewriter, - loc, - outputSliceType, - reduced, - SmallVector { - {0, 1}, - {2} - }); - Value hOffset = - affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); - SmallVector outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; - SmallVector outputSizes { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; - scf::YieldOp::create( + auto batchLoop = buildNormalizedScfFor( rewriter, loc, - tensor::InsertSliceOp::create( - rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) - .getResult()); - - rewriter.setInsertionPointAfter(hLoop); - scf::YieldOp::create(rewriter, loc, hLoop.getResult(0)); - - rewriter.setInsertionPointAfter(batchLoop); - Value paddedOutput = batchLoop.getResult(0); + c0, + cNumBatches, + c1, + ValueRange {outputInit}, + [&]( + OpBuilder&, Location batchLoc, Value batch, ValueRange batchIterArgs, SmallVectorImpl& batchYielded) { + auto hLoop = buildNormalizedScfFor( + rewriter, + batchLoc, + c0, + cNumOutHSlices, + c1, + ValueRange {batchIterArgs.front()}, + [&](OpBuilder&, Location hLoc, Value hSlice, ValueRange hIterArgs, SmallVectorImpl& hYielded) { + Value outputAcc = hIterArgs.front(); + Value reduced = reduceBatchedPartialPiecesForHSlice( + partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, hLoc); + Value expandedReduced = tensor::ExpandShapeOp::create(rewriter, + hLoc, + outputSliceType, + reduced, + SmallVector { + {0, 1}, + {2} + }); + Value hOffset = affineMulConst( + rewriter, hLoc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); + SmallVector outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; + SmallVector outputSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numOutRows), + rewriter.getIndexAttr(crossbarSize.getValue())}; + Value next = + tensor::InsertSliceOp::create( + rewriter, hLoc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) + .getResult(); + hYielded.push_back(next); + return success(); + }); + if (failed(hLoop)) + return failure(); + batchYielded.push_back(hLoop->results.front()); + return success(); + }); + if (failed(batchLoop)) + return failure(); + Value paddedOutput = batchLoop->results.front(); Value result = paddedOutput; if (paddedOutType != outType) { SmallVector outputOffsets { @@ -660,8 +683,11 @@ static Value createBatchedReductionCompute(Value partialPieces, rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)); } spatial::SpatYieldOp::create(rewriter, loc, result); + return success(); }); - return computeOp.getResult(0); + if (failed(computeOp)) + return failure(); + return computeOp->getResult(0); } struct MatMulShapeInfo { @@ -841,22 +867,27 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { numOutHSlices, rewriter, loc); - Value result = createBatchedReductionCompute(batchOp.getResult(0), - partialPiecesType, - directOutType, - paddedOutType, - shapeInfo->batch, - numKSlices, - rewriter, - loc); + if (failed(batchOp)) + return failure(); + auto result = createBatchedReductionCompute(batchOp->getResult(0), + partialPiecesType, + directOutType, + paddedOutType, + shapeInfo->batch, + numKSlices, + rewriter, + loc); + if (failed(result)) + return failure(); + Value finalResult = *result; if (useTransposedForm) - result = transposeBatchedOutput( - result, + finalResult = transposeBatchedOutput( + finalResult, RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), rewriter, loc); - result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); - rewriter.replaceOp(matmulOp, result); + finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); + rewriter.replaceOp(matmulOp, finalResult); return success(); } } @@ -873,16 +904,21 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { false, rewriter, loc); - Value result = - createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc); + if (failed(batchOp)) + return failure(); + auto result = + createBatchedDynamicOutputCompute(batchOp->getResult(0), scalarPiecesType, directOutType, rewriter, loc); + if (failed(result)) + return failure(); + Value finalResult = *result; if (useTransposedForm) - result = transposeBatchedOutput( - result, + finalResult = transposeBatchedOutput( + finalResult, RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), rewriter, loc); - result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); - rewriter.replaceOp(matmulOp, result); + finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); + rewriter.replaceOp(matmulOp, finalResult); return success(); } }; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index 2cab22e..596523c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -12,6 +12,7 @@ #include #include +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" @@ -275,86 +276,102 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); - auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); - rewriter.setInsertionPointToStart(outputLoop.getBody()); + auto outputLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cOutputPatchCount, + c1, + ValueRange {pooledOutputInit}, + [&](OpBuilder&, + Location nestedLoc, + Value outputPatchIndex, + ValueRange iterArgs, + SmallVectorImpl& yielded) { + Value pooledOutputAcc = iterArgs.front(); + Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch); + Value batchPatchIndex = + arith::RemUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch); + Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth); + Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth); + Value windowBaseH = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); + Value windowBaseW = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); - Value outputPatchIndex = outputLoop.getInductionVar(); - Value pooledOutputAcc = outputLoop.getRegionIterArgs().front(); + Value updatedOutput = pooledOutputAcc; + for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { + const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); + auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); + Value reducedWindow = + createPoolFillTensor(rewriter, nestedLoc, tileType, std::is_same_v); - Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); - Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); - Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); - Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); - Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); - Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + Value paddedInH = windowBaseH; + if (kernelH * dilationHeight != 0) { + Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight); + paddedInH = arith::AddIOp::create(rewriter, nestedLoc, paddedInH, kernelHOffset); + } - Value updatedOutput = pooledOutputAcc; - for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { - const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); - auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); - Value reducedWindow = - createPoolFillTensor(rewriter, loc, tileType, std::is_same_v); + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + Value paddedInW = windowBaseW; + if (kernelW * dilationWidth != 0) { + Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth); + paddedInW = arith::AddIOp::create(rewriter, nestedLoc, paddedInW, kernelWOffset); + } - for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { - Value paddedInH = windowBaseH; - if (kernelH * dilationHeight != 0) { - Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight); - paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); - } - - for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { - Value paddedInW = windowBaseW; - if (kernelW * dilationWidth != 0) { - Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth); - paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); + SmallVector offsets = { + batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value windowValue = + tensor::ExtractSliceOp::create(rewriter, nestedLoc, tileType, paddedInput, offsets, sizes, strides); + windowValue = materializeTileTensor(rewriter, nestedLoc, windowValue); + reducedWindow = ReduceOp::create(rewriter, nestedLoc, tileType, reducedWindow, windowValue); + } } - SmallVector offsets = { - batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(tileChannels), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - SmallVector strides = { + if constexpr (std::is_same_v) { + SmallVector scaleOffsets = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(channelTile * xbarSize), + outHeightIndex, + outWidthIndex}; + SmallVector scaleSizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector scaleStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value scaleSlice = tensor::ExtractSliceOp::create( + rewriter, nestedLoc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); + scaleSlice = materializeTileTensor(rewriter, nestedLoc, scaleSlice); + reducedWindow = spatial::SpatVMulOp::create(rewriter, nestedLoc, tileType, reducedWindow, scaleSlice); + } + + SmallVector outputOffsets = { + batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; + SmallVector outputSizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector outputStrides = { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value windowValue = - tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); - windowValue = materializeTileTensor(rewriter, loc, windowValue); - reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue); + updatedOutput = tensor::InsertSliceOp::create( + rewriter, nestedLoc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides); } - } + yielded.push_back(updatedOutput); + return success(); + }); + if (failed(outputLoop)) + return failure(); - if constexpr (std::is_same_v) { - SmallVector scaleOffsets = { - rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; - SmallVector scaleSizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(tileChannels), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - SmallVector scaleStrides = { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value scaleSlice = tensor::ExtractSliceOp::create( - rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); - scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice); - reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); - } - - SmallVector outputOffsets = { - batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; - SmallVector outputSizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(tileChannels), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - SmallVector outputStrides = { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - updatedOutput = tensor::InsertSliceOp::create( - rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides); - } - - scf::YieldOp::create(rewriter, loc, updatedOutput); - - rewriter.setInsertionPointAfter(outputLoop); - spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0)); + spatial::SpatYieldOp::create(rewriter, loc, outputLoop->results.front()); return success(); }); if (failed(computeOp)) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index 7657ea7..374bd8e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -42,13 +43,13 @@ static Value buildLoopSoftmaxSlice(Value input, return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides); } -static Value buildLoopSoftmaxNest(Value input, - Value accumulator, - RankedTensorType inputType, - int64_t axis, - SmallVectorImpl& outerIndices, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr buildLoopSoftmaxNest(Value input, + Value accumulator, + RankedTensorType inputType, + int64_t axis, + SmallVectorImpl& outerIndices, + ConversionPatternRewriter& rewriter, + Location loc) { if (axis == inputType.getRank() - 1) return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc); @@ -57,38 +58,50 @@ static Value buildLoopSoftmaxNest(Value input, Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis)); - auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator}); - rewriter.setInsertionPointToStart(loop.getBody()); - - Value loopIndex = loop.getInductionVar(); - Value loopAccumulator = loop.getRegionIterArgs().front(); - outerIndices.push_back(loopIndex); - Value updatedAccumulator = - buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc); - outerIndices.pop_back(); - - scf::YieldOp::create(rewriter, loc, updatedAccumulator); - rewriter.setInsertionPointAfter(loop); - return loop.getResult(0); + auto loop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cUpper, + c1, + ValueRange {accumulator}, + [&](OpBuilder& builder, Location nestedLoc, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + outerIndices.push_back(loopIndex); + auto updatedAccumulator = + buildLoopSoftmaxNest(input, iterArgs.front(), inputType, axis + 1, outerIndices, rewriter, nestedLoc); + outerIndices.pop_back(); + if (failed(updatedAccumulator)) + return failure(); + yielded.push_back(*updatedAccumulator); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); } -static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { +static FailureOr createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { auto inputType = cast(input.getType()); constexpr size_t numInputs = 1; - auto computeOp = - createSpatCompute(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { + auto computeOp = createSpatCompute( + rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult { if (inputType.getRank() == 1) { Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult(); spatial::SpatYieldOp::create(rewriter, loc, softmax); - return; + return success(); } Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType()); SmallVector outerIndices; - Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc); - spatial::SpatYieldOp::create(rewriter, loc, result); + auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc); + if (failed(result)) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, *result); + return success(); }); - return computeOp.getResult(0); + if (failed(computeOp)) + return failure(); + return computeOp->getResult(0); } struct SoftmaxToSpatialCompute : OpConversionPattern { @@ -108,7 +121,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { Value input = adaptor.getInput(); Value result; if (*axis == inputType.getRank() - 1) { - result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); + auto computed = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); + if (failed(computed)) + return failure(); + result = *computed; } else { SmallVector permutation; @@ -122,8 +138,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { auto transposedType = RankedTensorType::get( permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc()); - Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); - result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc()); + auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); + if (failed(transposedResult)) + return failure(); + result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc()); } rewriter.replaceOp(softmaxOp, result); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp index f42f4d3..d180ce5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp @@ -9,6 +9,7 @@ #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -192,13 +193,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(compute.getLaneCount())), - promoted->newWeights, - promoted->newInputs); + auto laneCountAttr = pim::getCheckedI32Attr( + rewriter, compute, static_cast(compute.getLaneCount()), "promoted compute_batch lane count"); + if (failed(laneCountAttr)) + return failure(); + auto newCompute = spatial::SpatComputeBatch::create( + rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs); auto laneArg = compute.getLaneArgument(); if (!laneArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index b8f8dbf..769a943 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -5,6 +5,7 @@ #include "llvm/ADT/STLExtras.h" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -26,11 +27,11 @@ static Value buildNearestAsymmetricIndex( return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast); } -static Value buildNearestResizeLoop(Value input, - RankedTensorType inputType, - RankedTensorType resultType, - ConversionPatternRewriter& rewriter, - Location loc) { +static FailureOr buildNearestResizeLoop(Value input, + RankedTensorType inputType, + RankedTensorType resultType, + ConversionPatternRewriter& rewriter, + Location loc) { auto elemType = resultType.getElementType(); SmallVector unitShape(resultType.getRank(), 1); auto unitTensorType = RankedTensorType::get(unitShape, elemType); @@ -48,54 +49,94 @@ static Value buildNearestResizeLoop(Value input, Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType); - auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit}); - rewriter.setInsertionPointToStart(batchLoop.getBody()); + auto batchLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cOutputN, + c1, + ValueRange {outputInit}, + [&](OpBuilder&, Location nestedLoc, Value outputN, ValueRange batchIterArgs, SmallVectorImpl& batchYielded) { + Value outputBatchAcc = batchIterArgs.front(); + Value inputN = + buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, nestedLoc); - Value outputN = batchLoop.getInductionVar(); - Value outputBatchAcc = batchLoop.getRegionIterArgs().front(); - Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc); + auto channelLoop = buildNormalizedScfFor( + rewriter, + nestedLoc, + c0, + cOutputC, + c1, + ValueRange {outputBatchAcc}, + [&](OpBuilder&, + Location channelLoc, + Value outputC, + ValueRange channelIterArgs, + SmallVectorImpl& channelYielded) { + Value outputChannelAcc = channelIterArgs.front(); + Value inputC = buildNearestAsymmetricIndex( + outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, channelLoc); - auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc}); - rewriter.setInsertionPointToStart(channelLoop.getBody()); + auto heightLoop = buildNormalizedScfFor( + rewriter, + channelLoc, + c0, + cOutputH, + c1, + ValueRange {outputChannelAcc}, + [&](OpBuilder&, + Location heightLoc, + Value outputH, + ValueRange heightIterArgs, + SmallVectorImpl& heightYielded) { + Value outputHeightAcc = heightIterArgs.front(); + Value inputH = buildNearestAsymmetricIndex( + outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, heightLoc); - Value outputC = channelLoop.getInductionVar(); - Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); - Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc); + auto widthLoop = buildNormalizedScfFor( + rewriter, + heightLoc, + c0, + cOutputW, + c1, + ValueRange {outputHeightAcc}, + [&](OpBuilder&, + Location widthLoc, + Value outputW, + ValueRange widthIterArgs, + SmallVectorImpl& widthYielded) { + Value outputWidthAcc = widthIterArgs.front(); + Value inputW = buildNearestAsymmetricIndex( + outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, widthLoc); - auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); - rewriter.setInsertionPointToStart(heightLoop.getBody()); + SmallVector inputOffsets = {inputN, inputC, inputH, inputW}; + Value inputSlice = tensor::ExtractSliceOp::create( + rewriter, widthLoc, unitTensorType, input, inputOffsets, unitSizes, unitStrides); - Value outputH = heightLoop.getInductionVar(); - Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); - Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc); - - auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); - rewriter.setInsertionPointToStart(widthLoop.getBody()); - - Value outputW = widthLoop.getInductionVar(); - Value outputWidthAcc = widthLoop.getRegionIterArgs().front(); - Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc); - - SmallVector inputOffsets = {inputN, inputC, inputH, inputW}; - Value inputSlice = - tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides); - - SmallVector outputOffsets = {outputN, outputC, outputH, outputW}; - Value updatedOutput = - tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides); - scf::YieldOp::create(rewriter, loc, updatedOutput); - - rewriter.setInsertionPointAfter(widthLoop); - scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0)); - - rewriter.setInsertionPointAfter(heightLoop); - scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0)); - - rewriter.setInsertionPointAfter(channelLoop); - scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0)); - - rewriter.setInsertionPointAfter(batchLoop); - return batchLoop.getResult(0); + SmallVector outputOffsets = {outputN, outputC, outputH, outputW}; + Value updatedOutput = tensor::InsertSliceOp::create( + rewriter, widthLoc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides); + widthYielded.push_back(updatedOutput); + return success(); + }); + if (failed(widthLoop)) + return failure(); + heightYielded.push_back(widthLoop->results.front()); + return success(); + }); + if (failed(heightLoop)) + return failure(); + channelYielded.push_back(heightLoop->results.front()); + return success(); + }); + if (failed(channelLoop)) + return failure(); + batchYielded.push_back(channelLoop->results.front()); + return success(); + }); + if (failed(batchLoop)) + return failure(); + return batchLoop->results.front(); } struct Resize : OpConversionPattern { @@ -120,12 +161,17 @@ struct Resize : OpConversionPattern { || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions."); - auto computeOp = - createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { - Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc()); - spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); + auto computeOp = createSpatCompute<1>( + rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) -> LogicalResult { + auto result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc()); + if (failed(result)) + return failure(); + spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), *result); + return success(); }); - rewriter.replaceOp(resizeOp, computeOp.getResults()); + if (failed(computeOp)) + return failure(); + rewriter.replaceOp(resizeOp, computeOp->getResults()); return success(); } }; diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 3ce8f29..9888cf4 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -10,6 +10,7 @@ #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -25,14 +26,21 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) { }); } -static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { +static FailureOr> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, + size_t& fallbackCoreId) { if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); SmallVector coreIds; coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); - for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) - coreIds.push_back(static_cast(fallbackCoreId++)); + for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) { + auto checkedCoreId = + pim::checkedI32(static_cast(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id"); + if (failed(checkedCoreId)) + return failure(); + coreIds.push_back(*checkedCoreId); + ++fallbackCoreId; + } return coreIds; } @@ -102,21 +110,24 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute "resultful compute_batch lowering currently requires a spat.in_parallel terminator"); } - SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + if (failed(coreIds)) + return failure(); SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector batchInputs; if (!computeBatchOp.getInputs().empty()) batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); rewriter.setInsertionPointAfter(computeBatchOp); - auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, - loc, - rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), - ValueRange(batchWeights), - ValueRange(batchInputs)); + auto laneCountAttr = pim::getCheckedI32Attr( + rewriter, computeBatchOp, static_cast(computeBatchOp.getLaneCount()), "pim core_batch lane count"); + if (failed(laneCountAttr)) + return failure(); + auto coreBatchOp = + pim::PimCoreBatchOp::create(rewriter, loc, *laneCountAttr, ValueRange(batchWeights), ValueRange(batchInputs)); coreBatchOp.getProperties().setOperandSegmentSizes( {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); - coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds)); SmallVector returnOperandIndices; if (computeBatchOp.getNumResults() != 0) { @@ -160,14 +171,11 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute auto newArgType = cast(newArg.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); - auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, - loc, - outputBuffer.getType(), - zeroOffset, - zeroOffset, - outputBuffer, - newArg, - getTensorSizeInBytesAttr(rewriter, newArg)) + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), newArg); + if (failed(sizeAttr)) + return failure(); + auto copied = pim::PimMemCopyHostToDevOp::create( + rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, newArg, *sizeAttr) .getOutput(); mapper.map(*oldArg, copied); } @@ -209,6 +217,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute auto hostTargetType = cast(hostTarget.getType()); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource); + if (failed(sizeAttr)) + return failure(); pim::PimMemCopyDevToHostOp::create(rewriter, insertSlice.getLoc(), hostTarget.getType(), @@ -216,7 +227,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute zeroOffset, hostTarget, mappedSource, - getTensorSizeInBytesAttr(rewriter, mappedSource)); + *sizeAttr); } continue; } @@ -232,15 +243,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute auto clonedType = cast(clonedTensor.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); - auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, - loc, - outputBuffer.getType(), - zeroOffset, - zeroOffset, - outputBuffer, - clonedTensor, - getTensorSizeInBytesAttr(rewriter, clonedTensor)) - .getOutput(); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), clonedTensor); + if (failed(sizeAttr)) + return failure(); + auto copied = + pim::PimMemCopyHostToDevOp::create( + rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, clonedTensor, *sizeAttr) + .getOutput(); mapper.map(toTensorOp.getResult(), copied); continue; } diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 4d304f4..72480c1 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -5,14 +5,18 @@ #include #include "Common.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" using namespace llvm; using namespace mlir; namespace onnx_mlir { -IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { - return builder.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(cast(value.getType())))); +FailureOr getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) { + auto byteSize = pim::getCheckedShapedTypeSizeInBytes(cast(value.getType()), anchor, "tensor byte size"); + if (failed(byteSize)) + return failure(); + return pim::getCheckedI32Attr(builder, anchor, *byteSize, "tensor byte size"); } Operation* getEarliestUserWithinBlock(mlir::Value value) { diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index db87abe..49fe3ec 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -1,12 +1,14 @@ #pragma once #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LogicalResult.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" namespace onnx_mlir { -mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); +mlir::FailureOr +getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Operation* anchor, mlir::Value value); template size_t rangeLength(const mlir::iterator_range range) { diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 5388dfa..142bc9a 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -9,6 +9,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -54,10 +55,15 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite } } -static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { +static FailureOr getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return static_cast(spatialCoreIdAttr.getInt()); - return static_cast(fallbackCoreId++); + return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id"); + auto checkedCoreId = + pim::checkedI32(static_cast(fallbackCoreId), computeOp, "fallback spatial compute core id"); + if (failed(checkedCoreId)) + return failure(); + ++fallbackCoreId; + return *checkedCoreId; } static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, @@ -163,10 +169,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg)); auto outputType = cast(blockArg->getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), *blockArg); + if (failed(sizeAttr)) + return failure(); Value received = PimReceiveOp::create( - rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) + rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, *sizeAttr, receiveOp.getSourceCoreId()) .getOutput(); blockArg->replaceAllUsesWith(received); markOpToRemove(receiveOp); @@ -206,8 +214,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp if (!computeOp.getWeights().empty()) computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); rewriter.setInsertionPointAfter(computeOp); - auto coreOp = PimCoreOp::create( - rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); + auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId); + if (failed(checkedCoreId)) + return failure(); + auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast(*checkedCoreId), "pim core id"); + if (failed(coreIdAttr)) + return failure(); + auto coreOp = PimCoreOp::create(rewriter, loc, ValueRange(computeWeights), *coreIdAttr); rewriter.setInsertionPointToStart(&block); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { @@ -226,6 +239,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp if (!inputType) return computeOp.emitOpError("expected shaped compute input during pim.core lowering"); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), input); + if (failed(sizeAttr)) + return failure(); auto copied = PimMemCopyHostToDevOp::create(rewriter, loc, @@ -234,7 +250,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0), outputBuffer, input, - getTensorSizeInBytesAttr(rewriter, input)) + *sizeAttr) .getOutput(); blockArg->replaceAllUsesWith(copied); } diff --git a/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp index 996add3..31d8823 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp @@ -14,8 +14,10 @@ struct ChannelSendLowering : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override { - pim::PimSendOp::create( - rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId()); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput()); + if (failed(sizeAttr)) + return failure(); + pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId()); rewriter.eraseOp(op); return success(); } @@ -32,12 +34,11 @@ struct ChannelReceiveLowering : OpRewritePattern auto outputType = cast(op.getResult().getType()); Value outputBuffer = tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); - Value received = pim::PimReceiveOp::create(rewriter, - op.getLoc(), - op.getResult().getType(), - outputBuffer, - getTensorSizeInBytesAttr(rewriter, op.getResult()), - op.getSourceCoreId()) + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult()); + if (failed(sizeAttr)) + return failure(); + Value received = pim::PimReceiveOp::create( + rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId()) .getOutput(); rewriter.replaceOp(op, received); return success(); diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 32dcea3..11dccf4 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -12,6 +12,7 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -71,6 +72,20 @@ static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef +getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* anchor, StringRef fieldName) { + if (elementOffset < 0) { + anchor->emitOpError() << fieldName << " requires a nonnegative element offset"; + return failure(); + } + + auto byteOffset = + pim::checkedMul(static_cast(elementOffset), static_cast(elementSize), anchor, fieldName); + if (failed(byteOffset)) + return failure(); + return pim::checkedCast(*byteOffset, anchor, fieldName); +} + static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallVectorImpl& helperChain) { if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) @@ -360,18 +375,21 @@ static void cloneHelperChain(Value sourceValue, } } -static Value emitHostCopy(IRRewriter& rewriter, - Location loc, - Value outputTensor, - Value sourceValue, - int32_t hostTargetOffset, - int32_t deviceSourceOffset, - int32_t sizeInBytes, - OperationFolder& constantFolder) { +static FailureOr emitHostCopy(IRRewriter& rewriter, + Location loc, + Value outputTensor, + Value sourceValue, + int64_t hostTargetOffset, + int64_t deviceSourceOffset, + uint64_t sizeInBytes, + OperationFolder& constantFolder) { Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset); Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset); + auto sizeAttr = pim::getCheckedI32Attr(rewriter, anchorOp, sizeInBytes, "return-path host copy byte size"); + if (failed(sizeAttr)) + return failure(); return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), @@ -379,7 +397,7 @@ static Value emitHostCopy(IRRewriter& rewriter, deviceSourceOffsetValue, outputTensor, sourceValue, - rewriter.getI32IntegerAttr(sizeInBytes)) + *sizeAttr) .getOutput(); } @@ -433,18 +451,15 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low markOpToRemove(op); auto storedType = cast(currentStoredValue.getType()); - size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType()); + auto byteSize = pim::getCheckedShapedTypeSizeInBytes(storedType, producerOp, "return-path host copy byte size"); + if (failed(byteSize)) + return ReturnPathLoweringResult::Failure; if (auto storedOp = currentStoredValue.getDefiningOp()) rewriter.setInsertionPointAfter(storedOp); Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); - emitHostCopy(rewriter, - loc, - outputTensor, - currentStoredValue, - 0, - 0, - static_cast(storedType.getNumElements() * elementSize), - constantFolder); + auto copied = emitHostCopy(rewriter, loc, outputTensor, currentStoredValue, 0, 0, *byteSize, constantFolder); + if (failed(copied)) + return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Handled; } @@ -455,23 +470,25 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); - size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); + auto byteSize = + pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size"); + if (failed(byteSize)) + return ReturnPathLoweringResult::Failure; rewriter.setInsertionPointAfterValue(storedValue); Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); - emitHostCopy(rewriter, - loc, - outputTensor, - storedValue, - 0, - 0, - static_cast(storedTensorType.getNumElements() * elementSize), - constantFolder); + auto copied = emitHostCopy(rewriter, loc, outputTensor, storedValue, 0, 0, *byteSize, constantFolder); + if (failed(copied)) + return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Handled; } } if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); + auto storedByteSize = + pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "concat return-path copy byte size"); + if (failed(storedByteSize)) + return ReturnPathLoweringResult::Failure; for (Operation* concatOp : concatReturnUse->concatChain) markOpToRemove(concatOp); @@ -480,14 +497,13 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); auto outputType = cast(outputTensor.getType()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); - emitHostCopy(rewriter, - loc, - outputTensor, - storedValue, - static_cast(flatOffset * elementSize), - 0, - static_cast(storedTensorType.getNumElements() * elementSize), - constantFolder); + auto hostOffset = getCheckedByteOffset(flatOffset, elementSize, producerOp, "concat return-path host offset"); + if (failed(hostOffset)) + return ReturnPathLoweringResult::Failure; + auto copied = + emitHostCopy(rewriter, loc, outputTensor, storedValue, *hostOffset, 0, *storedByteSize, constantFolder); + if (failed(copied)) + return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Handled; } @@ -531,14 +547,18 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low rewriter.setInsertionPointAfter(elementSlice); int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); - outputTensor = emitHostCopy(rewriter, - loc, - outputTensor, - elementSlice.getResult(), - static_cast(destinationFlatOffset * elementSize), - 0, - static_cast(elementSize), - constantFolder); + auto hostOffset = + getCheckedByteOffset(destinationFlatOffset, elementSize, producerOp, "concat helper return-path host offset"); + if (failed(hostOffset)) + return ReturnPathLoweringResult::Failure; + auto elementByteSize = pim::checkedCast(elementSize, producerOp, "return-path scalar copy byte size"); + if (failed(elementByteSize)) + return ReturnPathLoweringResult::Failure; + auto copied = emitHostCopy( + rewriter, loc, outputTensor, elementSlice.getResult(), *hostOffset, 0, *elementByteSize, constantFolder); + if (failed(copied)) + return ReturnPathLoweringResult::Failure; + outputTensor = *copied; } return ReturnPathLoweringResult::Handled; } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 957e522..f057626 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -25,8 +25,9 @@ #include #include -#include "Common/PimCommon.hpp" #include "Common/IR/ConstantUtils.hpp" +#include "Common/PimCommon.hpp" +#include "Common/Support/CheckedArithmetic.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Patterns.hpp" @@ -75,21 +76,28 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc IntegerAttr {}); } -static Value createZeroedDeviceHVector(IRRewriter& rewriter, - Location loc, - RankedTensorType tensorType, - OperationFolder& constantFolder) { +static FailureOr createZeroedDeviceHVector(IRRewriter& rewriter, + Location loc, + RankedTensorType tensorType, + OperationFolder& constantFolder) { auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); - auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); + auto byteSize = + pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size"); + if (failed(byteSize)) + return failure(); + auto sizeAttr = + pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size"); + if (failed(sizeAttr)) + return failure(); return PimMemCopyHostToDevOp::create( - rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr) + rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr) .getOutput(); } -static Value +static FailureOr padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { auto vectorType = cast(vector.getType()); ArrayRef shape = vectorType.getShape(); @@ -101,10 +109,18 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, auto paddedType = RankedTensorType::get( {shape[0], static_cast(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); - Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); - Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0); - auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(vectorType))); - return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput(); + auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); + if (failed(zeroed)) + return failure(); + Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0); + auto byteSize = + pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size"); + if (failed(byteSize)) + return failure(); + auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size"); + if (failed(sizeAttr)) + return failure(); + return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput(); } void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { @@ -234,7 +250,11 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } } - enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); + if (failed(enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter))) { + funcOp.emitOpError("failed to enlarge VMM output tensors to crossbar size"); + signalPassFailure(); + return; + } replaceReturnWithOutputBuffers(returnOp, rewriter); eraseOpsToRemove(); @@ -271,8 +291,9 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { dumpModule(moduleOp, "pim0"); } -void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { +LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { OperationFolder constantFolder(funcOp.getContext()); + bool hasFailure = false; funcOp.walk([&](PimVMMOp vmmOp) { auto outputType = cast(vmmOp.getOutput().getType()); ArrayRef outputShape = outputType.getShape(); @@ -280,19 +301,23 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f assert(outputShape[1] <= static_cast(crossbarSize) && "output width must fit in one crossbar"); rewriter.setInsertionPoint(vmmOp); - Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); + auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); + if (failed(paddedInput)) { + hasFailure = true; + return WalkResult::interrupt(); + } auto paddedOutputType = RankedTensorType::get( {outputShape[0], static_cast(crossbarSize)}, outputType.getElementType(), outputType.getEncoding()); Value paddedOutputBuffer = outputShape[1] == static_cast(crossbarSize) ? vmmOp.getOutputBuffer() : createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult(); - vmmOp.getInputMutable().assign(paddedInput); + vmmOp.getInputMutable().assign(*paddedInput); vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer); vmmOp.getOutput().setType(paddedOutputType); if (outputShape[1] == static_cast(crossbarSize)) - return; + return WalkResult::advance(); SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])}; @@ -302,13 +327,16 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides); SmallPtrSet exceptions = {vmmOp, sliceOp}; vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions); + return WalkResult::advance(); }); + return success(!hasFailure); } LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); OperationFolder constantFolder(funcOp.getContext()); + bool hasFailure = false; auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto tensorType = cast(inputTensor.getType()); @@ -319,17 +347,28 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); + auto offsetBytes = pim::checkedMul( + static_cast(elementsOffset), elementByteSize, deviceTensor.getOperation(), "host input byte offset"); + auto byteSize = + pim::getCheckedShapedTypeSizeInBytes(tensorType, deviceTensor.getOperation(), "host input copy byte size"); + auto sizeAttr = + succeeded(byteSize) + ? pim::getCheckedI32Attr(rewriter, deviceTensor.getOperation(), *byteSize, "host input copy byte size") + : FailureOr(failure()); + if (failed(offsetBytes) || failed(sizeAttr)) { + hasFailure = true; + return; + } auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create( rewriter, loc, tensorType, getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0), - getOrCreateIndexConstant( - constantFolder, deviceTensor.getOperation(), static_cast(elementsOffset * elementByteSize)), + getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), static_cast(*offsetBytes)), deviceTensor, inputTensor, - rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); + *sizeAttr); rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; @@ -347,7 +386,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( } } - return success(); + return success(!hasFailure); } void raptor::SpatialToPimPass::markOpToRemove(Operation* op) { diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp index 010aadc..7508f3c 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp @@ -64,7 +64,7 @@ private: void markOpToRemove(mlir::Operation* op); void eraseOpsToRemove(); - void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); + mlir::LogicalResult enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); }; } // namespace raptor diff --git a/src/PIM/Dialect/Pim/CMakeLists.txt b/src/PIM/Dialect/Pim/CMakeLists.txt index 2ad22ca..0e3d18d 100644 --- a/src/PIM/Dialect/Pim/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/CMakeLists.txt @@ -2,7 +2,10 @@ add_onnx_mlir_dialect(Pim pim) add_onnx_mlir_dialect_doc(pim Pim.td) add_subdirectory(Transforms/Bufferization) -add_subdirectory(Transforms/StaticMemoryCoalescing) +add_subdirectory(Transforms/MemoryCoalescing) +add_subdirectory(Transforms/HostConstantFolding) +add_subdirectory(Transforms/HostConstantMaterialization) +add_subdirectory(Transforms/Verification) add_pim_library(PimOps PimOps.hpp diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index cb14469..260389b 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -3,6 +3,7 @@ #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" @@ -11,24 +12,25 @@ using namespace bufferization; namespace onnx_mlir::pim { -Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { +FailureOr materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) return memrefValue; auto shapedType = cast(memrefValue.getType()); auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); - auto sizeInBytes = getShapedTypeSizeInBytes(shapedType); + auto sizeInBytes = + getCheckedShapedTypeSizeInBytes(shapedType, contiguousBuffer.getDefiningOp(), "contiguous copy byte size"); + if (failed(sizeInBytes)) + return failure(); Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0); + auto sizeAttr = + getCheckedI32Attr(rewriter, contiguousBuffer.getDefiningOp(), *sizeInBytes, "contiguous copy byte size"); + if (failed(sizeAttr)) + return failure(); - return PimMemCopyOp::create(rewriter, - loc, - contiguousType, - zeroOffset, - zeroOffset, - contiguousBuffer, - memrefValue, - rewriter.getI32IntegerAttr(sizeInBytes)) + return PimMemCopyOp::create( + rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) .getOutput(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp index d9bc034..590afec 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp @@ -5,7 +5,8 @@ namespace onnx_mlir::pim { -mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); +llvm::FailureOr +materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); mlir::Value allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp index 395806d..321347e 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp @@ -1,10 +1,13 @@ #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" using namespace mlir; -IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { +FailureOr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) { auto type = mlir::cast(memref.getType()); - int32_t sizeInBytes = static_cast(getShapedTypeSizeInBytes(type)); - return builder.getI32IntegerAttr(sizeInBytes); + auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size"); + if (failed(byteSize)) + return failure(); + return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size"); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp index 961f724..6816d62 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp @@ -5,7 +5,8 @@ namespace onnx_mlir { namespace pim { -mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref); +mlir::FailureOr +getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref); } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp index 9e00bde..995b5a7 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.cpp @@ -4,8 +4,10 @@ #include "ContiguityPatterns.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" using namespace mlir; @@ -85,7 +87,13 @@ static FailureOr> getStaticMemRefStrides(MemRefType type) { static FailureOr getShapedByteSize(MemRefType type) { if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType())) return failure(); - return static_cast(getShapedTypeSizeInBytes(type)); + auto byteSize = + pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "normalized copy byte size"); + if (failed(byteSize)) + return failure(); + if (*byteSize > static_cast(std::numeric_limits::max())) + return failure(); + return static_cast(*byteSize); } static FailureOr> @@ -325,12 +333,11 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp, if (plan->kind == CopyRewritePlan::Kind::Direct) { Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset); Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset); - auto newCopyOp = createCopyOp(loc, - plan->target.base, - plan->source.base, - newTargetOffset, - newSourceOffset, - static_cast(plan->directBytes)); + auto checkedDirectBytes = pim::checkedI32(plan->directBytes, anchorOp, "normalized direct copy byte size"); + if (failed(checkedDirectBytes)) + return failure(); + auto newCopyOp = + createCopyOp(loc, plan->target.base, plan->source.base, newTargetOffset, newSourceOffset, *checkedDirectBytes); assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); rewriter.replaceOp(copyOp, replacementValue); return success(); @@ -339,23 +346,30 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp, Value c0 = createIndexConstant(rewriter, anchorOp, 0); Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape)); Value cStep = createIndexConstant(rewriter, anchorOp, 1); - auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {}); - rewriter.setInsertionPointToStart(loop.getBody()); - - SmallVector outerIndices = - materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape); - Value loopTargetOffset = materializeOuterByteOffset( - rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides); - Value loopSourceOffset = materializeOuterByteOffset( - rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides); - auto newCopyOp = createCopyOp(loc, - plan->target.base, - plan->source.base, - loopTargetOffset, - loopSourceOffset, - static_cast(plan->loop.chunkBytes)); - assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); - rewriter.setInsertionPointAfter(loop); + auto loop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cUpper, + cStep, + ValueRange {}, + [&](OpBuilder&, Location nestedLoc, Value inductionVar, ValueRange iterArgs, SmallVectorImpl& yielded) { + SmallVector outerIndices = + materializeDelinearizedIndices(rewriter, nestedLoc, anchorOp, inductionVar, plan->loop.outerShape); + Value loopTargetOffset = materializeOuterByteOffset( + rewriter, nestedLoc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides); + Value loopSourceOffset = materializeOuterByteOffset( + rewriter, nestedLoc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides); + auto checkedChunkBytes = pim::checkedI32(plan->loop.chunkBytes, anchorOp, "normalized loop copy byte size"); + if (failed(checkedChunkBytes)) + return failure(); + auto newCopyOp = createCopyOp( + nestedLoc, plan->target.base, plan->source.base, loopTargetOffset, loopSourceOffset, *checkedChunkBytes); + assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); + return success(); + }); + if (failed(loop)) + return failure(); rewriter.replaceOp(copyOp, replacementValue); return success(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 0777fb2..db77dfa 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -148,7 +148,10 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter)); + auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + if (failed(contiguous)) + return failure(); + inputs.push_back(*contiguous); } auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state); @@ -179,12 +182,12 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModelgetLoc(), rewriter); + if (failed(contiguousInput)) + return failure(); - replaceOpWithNewBufferizedOp(rewriter, - op, - materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter), - sendOp.getSizeAttr(), - sendOp.getTargetCoreId()); + replaceOpWithNewBufferizedOp( + rewriter, op, *contiguousInput, sendOp.getSizeAttr(), sendOp.getTargetCoreId()); return success(); } }; @@ -407,11 +410,13 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + if (failed(contiguousInput)) + return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput); + rewriter, op, contiguousOutput.getType(), *contiguousInput, transposeOp.getPermutation(), contiguousOutput); return success(); } }; @@ -451,11 +456,13 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + if (failed(contiguousInput)) + return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput); + rewriter, op, contiguousOutput.getType(), *weightOpt, *contiguousInput, contiguousOutput); return success(); } }; @@ -490,12 +497,16 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); + if (failed(contiguousLhs)) + return failure(); + auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + if (failed(contiguousRhs)) + return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); + rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput); return success(); } }; @@ -523,12 +534,16 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); - Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); + if (failed(contiguousLhs)) + return failure(); + auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); + if (failed(contiguousRhs)) + return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); + rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput); return success(); } }; @@ -559,10 +574,12 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); + if (failed(contiguousInput)) + return failure(); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); - replaceOpWithNewBufferizedOp(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput); + replaceOpWithNewBufferizedOp(rewriter, op, contiguousOutput.getType(), *contiguousInput, contiguousOutput); return success(); } }; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index aa53879..01e35c4 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -42,7 +42,9 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern { return failure(); Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); - IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource()); + auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource()); + if (failed(sizeAttr)) + return failure(); pim::PimMemCopyOp::create(rewriter, copyOp.getLoc(), copyOp.getTarget().getType(), @@ -50,7 +52,7 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern { zeroOffset, copyOp.getTarget(), copyOp.getSource(), - sizeAttr); + *sizeAttr); rewriter.eraseOp(copyOp); return success(); } diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/CMakeLists.txt new file mode 100644 index 0000000..7b815ab --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/CMakeLists.txt @@ -0,0 +1,12 @@ +add_pim_library(OMPimHostConstantFolding + Common.cpp + Patterns/Constant.cpp + HostConstantFoldingPass.cpp + Patterns/Subview.cpp + + EXCLUDE_FROM_OM_LIBS + + LINK_LIBS PUBLIC + MLIRLinalgDialect + OMPimCommon +) diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Common.cpp similarity index 100% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Common.cpp diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Common.hpp similarity index 100% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Common.hpp diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/HostConstantFoldingPass.cpp similarity index 100% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/HostConstantFoldingPass.cpp diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns.hpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns.hpp similarity index 100% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns.hpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns.hpp diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp similarity index 97% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp index a72db50..2fa900f 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp @@ -7,6 +7,7 @@ #include "../Common.hpp" #include "../Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -120,16 +121,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { rewriter.setInsertionPoint(mapOp); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); - auto sizeInBytes = getShapedTypeSizeInBytes(initType); + auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(initType, mapOp, "host constant folding byte size"); + if (failed(sizeInBytes)) + return failure(); Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0); - pim::PimMemCopyOp::create(rewriter, - mapOp.getLoc(), - initType, - zeroOffset, - zeroOffset, - mapOp.getInit(), - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(sizeInBytes)); + auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size"); + if (failed(sizeAttr)) + return failure(); + pim::PimMemCopyOp::create( + rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr); rewriter.eraseOp(mapOp); return success(); } diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Subview.cpp similarity index 100% rename from src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Subview.cpp diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt new file mode 100644 index 0000000..c2105c2 --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt @@ -0,0 +1,9 @@ +add_pim_library(OMPimHostConstantMaterialization + MaterializeHostConstantsPass.cpp + + EXCLUDE_FROM_OM_LIBS + + LINK_LIBS PUBLIC + OMPimCommon + PimOps +) diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp similarity index 79% rename from src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp rename to src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp index d7b3c88..97850a3 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp @@ -12,6 +12,7 @@ #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -65,11 +66,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, continue; } - int64_t totalBytes = -1; - if (auto type = dyn_cast(originalValue.getType()); type && type.hasStaticShape()) - totalBytes = static_cast(getShapedTypeSizeInBytes(type)); - if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { - op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); + auto type = dyn_cast(originalValue.getType()); + auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size") + : FailureOr(failure()); + auto totalBytesAttr = + succeeded(totalBytes) + ? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size") + : FailureOr(failure()); + if (failed(totalBytesAttr) + || failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) { hasFailure = true; continue; } @@ -84,16 +89,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0); Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset); - Value copiedValue = - pim::PimMemCopyHostToDevOp::create(rewriter, - op->getLoc(), - originalType, - zeroOffset, - hostOffset, - deviceDst, - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(static_cast(totalBytes))) - .getOutput(); + Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter, + op->getLoc(), + originalType, + zeroOffset, + hostOffset, + deviceDst, + getGlobalOp.getResult(), + *totalBytesAttr) + .getOutput(); cachedByType[originalType] = copiedValue; operand.set(copiedValue); diff --git a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/CMakeLists.txt new file mode 100644 index 0000000..d746087 --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/CMakeLists.txt @@ -0,0 +1,14 @@ +add_pim_library(OMPimMemoryCoalescing + MemoryCoalescing.cpp + MemoryCoalescing.hpp + MemoryCoalescingPass.cpp + + EXCLUDE_FROM_OM_LIBS + + INCLUDE_DIRS PUBLIC + ${PIM_PUBLIC_INCLUDE_DIRS} + + LINK_LIBS PUBLIC + OMPimCommon + PimOps +) diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp similarity index 77% rename from src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp rename to src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp index 1e49538..4153489 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp @@ -10,7 +10,7 @@ #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp" using namespace mlir; @@ -32,6 +32,13 @@ static uint64_t getTypeSizeBytes(MemRefType type) { return static_cast(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType())); } +static Operation* getTopLevelAncestorInBody(Operation* op, Block& body) { + Operation* current = op; + while (current && current->getBlock() != &body) + current = current->getParentOp(); + return current; +} + static FailureOr getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap& opOrder) { uint64_t endInstruction = opOrder.lookup(allocOp); @@ -42,7 +49,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapgetBlock() != &body) + Operation* orderedUser = getTopLevelAncestorInBody(user, body); + if (!orderedUser) return failure(); if (!visited.insert(user).second) continue; @@ -51,6 +59,15 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapgetResults()) pendingValues.push_back(result); + if (auto yieldOp = dyn_cast(user)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return failure(); + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) + if (operand == value) + pendingValues.push_back(forOp.getResult(index)); + } + if (auto forOp = dyn_cast(user)) { for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) if (initArg == value) @@ -66,7 +83,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapsecond); @@ -78,8 +95,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapgetNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) return analysis; @@ -107,18 +124,19 @@ StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation } analysis.candidates.push_back( - StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)}); + AllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)}); } return analysis; } -StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) { - StaticMemoryCoalescingStats stats; - auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp); +MemoryCoalescingStats +coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) { + MemoryCoalescingStats stats; stats.skippedAllocations = analysis.skippedAllocations; - llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) { + auto candidates = analysis.candidates; + llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) { if (lhs.startInstruction != rhs.startInstruction) return lhs.startInstruction < rhs.startInstruction; return lhs.endInstruction < rhs.endInstruction; @@ -132,7 +150,7 @@ StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, Rewriter SmallVector active; SmallVector freeList; - for (StaticAllocationCandidate& candidate : analysis.candidates) { + for (AllocationCandidate& candidate : candidates) { for (auto it = active.begin(); it != active.end();) { if (it->endInstruction < candidate.startInstruction) { freeList.push_back(it->root); diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp similarity index 55% rename from src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp rename to src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp index 3c656f6..e0b4025 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp @@ -8,27 +8,28 @@ namespace onnx_mlir { namespace pim { -struct StaticAllocationCandidate { +struct AllocationCandidate { mlir::memref::AllocOp alloc; uint64_t startInstruction = 0; uint64_t endInstruction = 0; uint64_t sizeBytes = 0; }; -struct StaticMemoryCoalescingAnalysis { - llvm::SmallVector candidates; +struct MemoryCoalescingAnalysis { + llvm::SmallVector candidates; uint64_t skippedAllocations = 0; }; -struct StaticMemoryCoalescingStats { +struct MemoryCoalescingStats { uint64_t removedAllocs = 0; uint64_t savedBytes = 0; uint64_t skippedAllocations = 0; }; -StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp); +MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(mlir::Operation* coreLikeOp); -StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter); +MemoryCoalescingStats +coalesceMemory(mlir::Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, mlir::RewriterBase& rewriter); } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp similarity index 82% rename from src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp rename to src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp index ec7647a..5422a14 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp @@ -10,10 +10,11 @@ #include "Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" using namespace mlir; @@ -151,34 +152,39 @@ static void emitReport(ArrayRef entries) { file.close(); } -struct StaticMemoryCoalescingPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass) +struct PimMemoryCoalescingPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMemoryCoalescingPass) - StringRef getArgument() const override { return "pim-static-memory-coalescing"; } - StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; } + StringRef getArgument() const override { return "pim-memory-coalescing"; } + StringRef getDescription() const override { return "Analyze local PIM memory reuse opportunities"; } - StaticMemoryCoalescingPass() = default; - StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {} + PimMemoryCoalescingPass() = default; + PimMemoryCoalescingPass(const PimMemoryCoalescingPass& pass) {} void runOnOperation() override { IRRewriter rewriter(&getContext()); SmallVector reportEntries; uint64_t nextBatchId = 0; + bool hasFailure = false; getOperation().walk([&](Operation* op) { - if (!isa(op)) + if (hasFailure || !isa(op)) return; - auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op); - auto stats = pim::coalesceStaticMemory(op, rewriter); + auto analysis = pim::analyzeMemoryCoalescingCandidates(op); + auto stats = pim::coalesceMemory(op, analysis, rewriter); CoalescingReportRow row { analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes}; if (auto coreOp = dyn_cast(op)) { - reportEntries.push_back({CoalescingReportEntry::Kind::Core, - static_cast(coreOp.getCoreId()), - {static_cast(coreOp.getCoreId())}, - row}); + auto checkedCoreId = + pim::checkedI32(static_cast(coreOp.getCoreId()), coreOp, "memory coalescing core id"); + if (failed(checkedCoreId)) { + hasFailure = true; + return; + } + reportEntries.push_back( + {CoalescingReportEntry::Kind::Core, static_cast(coreOp.getCoreId()), {*checkedCoreId}, row}); return; } @@ -191,6 +197,11 @@ struct StaticMemoryCoalescingPass : PassWrapper createPimStaticMemoryCoalescingPass() { return std::make_unique(); } +std::unique_ptr createPimMemoryCoalescingPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/CMakeLists.txt deleted file mode 100644 index 916b12f..0000000 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -add_pim_library(OMPimStaticMemoryCoalescing - StaticMemoryCoalescing.cpp - StaticMemoryCoalescing.hpp - StaticMemoryCoalescingPass.cpp - - EXCLUDE_FROM_OM_LIBS - - INCLUDE_DIRS PUBLIC - ${PIM_PUBLIC_INCLUDE_DIRS} - - LINK_LIBS PUBLIC - OMPimCommon - PimOps -) diff --git a/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt new file mode 100644 index 0000000..67009ac --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Verification/CMakeLists.txt @@ -0,0 +1,11 @@ +add_pim_library(OMPimVerification + VerificationPass.cpp + + EXCLUDE_FROM_OM_LIBS + + LINK_LIBS PUBLIC + OMPimCommon + OMPimBufferization + PimOps + SpatialOps +) diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp similarity index 100% rename from src/PIM/Pass/PimCodegen/VerificationPass.cpp rename to src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 4a56f9e..06df444 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -23,7 +23,9 @@ #include "Scheduling/ComputeInstanceUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -36,6 +38,23 @@ using CpuId = size_t; using ClassId = size_t; using SlotId = size_t; +static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { + return pim::checkedI32(static_cast(cpu), anchor, fieldName); +} + +static FailureOr> +getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { + SmallVector coreIds; + coreIds.reserve(cpus.size()); + for (CpuId cpu : cpus) { + auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); + if (failed(checkedCoreId)) + return failure(); + coreIds.push_back(*checkedCoreId); + } + return coreIds; +} + struct ProducerKey { ComputeInstance instance; size_t resultIndex = 0; @@ -498,7 +517,7 @@ LogicalResult collectHostOutputs(MaterializerState& state) { return success(); } -void createEmptyMaterializedOps(MaterializerState& state) { +LogicalResult createEmptyMaterializedOps(MaterializerState& state) { Location loc = state.func.getLoc(); Block& funcBlock = state.func.getBody().front(); @@ -524,8 +543,11 @@ void createEmptyMaterializedOps(MaterializerState& state) { if (!materializedClass.isBatch) { auto compute = SpatCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); compute.getProperties().setOperandSegmentSizes({0, 0}); - compute->setAttr(onnx_mlir::kCoreIdAttrName, - state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.front()))); + auto coreIdAttr = + pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); + if (failed(coreIdAttr)) + return failure(); + compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); Block* body = state.rewriter.createBlock(&compute.getBody()); state.rewriter.setInsertionPointToEnd(body); SmallVector placeholderOutputs; @@ -534,7 +556,7 @@ void createEmptyMaterializedOps(MaterializerState& state) { auto tensorType = dyn_cast(resultType); if (!tensorType || !tensorType.hasStaticShape()) { compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); - continue; + return failure(); } placeholderOutputs.push_back( tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); @@ -546,19 +568,17 @@ void createEmptyMaterializedOps(MaterializerState& state) { continue; } - auto batch = - SpatComputeBatch::create(state.rewriter, - loc, - TypeRange(resultTypes), - state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.size())), - ValueRange {}, - ValueRange {}); + auto batchLaneCountAttr = pim::getCheckedI32Attr( + state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); + if (failed(batchLaneCountAttr)) + return failure(); + auto batch = SpatComputeBatch::create( + state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); batch.getProperties().setOperandSegmentSizes({0, 0}); - SmallVector coreIds; - coreIds.reserve(materializedClass.cpus.size()); - for (CpuId cpu : materializedClass.cpus) - coreIds.push_back(static_cast(cpu)); - batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(coreIds)); + auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); + if (failed(coreIds)) + return failure(); + batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds)); SmallVector blockArgTypes {state.rewriter.getIndexType()}; SmallVector blockArgLocs {loc}; @@ -575,6 +595,8 @@ void createEmptyMaterializedOps(MaterializerState& state) { materializedClass.body = body; state.rewriter.setInsertionPointAfter(batch.getOperation()); } + + return success(); } BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { @@ -787,15 +809,15 @@ SmallVector widenToI64(ArrayRef values) { return widened; } -Value createReceiveConcatLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, - RankedTensorType concatType, - RankedTensorType fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc); +FailureOr createReceiveConcatLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + RankedTensorType concatType, + RankedTensorType fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc); FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, MaterializedClass& targetClass, @@ -853,15 +875,18 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, SmallVector sourceCoreIds = widenToI64(run.sourceCoreIds); SmallVector targetCoreIds = widenToI64(run.targetCoreIds); - run.packed = createReceiveConcatLoop(state, - targetClass.op, - targetClass.body->getTerminator(), - *fullPackedType, - run.fragmentType, - run.channelIds, - sourceCoreIds, - targetCoreIds, - loc); + auto packed = createReceiveConcatLoop(state, + targetClass.op, + targetClass.body->getTerminator(), + *fullPackedType, + run.fragmentType, + run.channelIds, + sourceCoreIds, + targetCoreIds, + loc); + if (failed(packed)) + return failure(); + run.packed = *packed; return run.packed; } @@ -1559,13 +1584,13 @@ void appendScalarSend(MaterializerState& state, SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); } -void appendScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { +LogicalResult appendScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); assert(channelIds.size() > 1 && "send loop is only useful for multiple sends"); assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); @@ -1578,25 +1603,31 @@ void appendScalarSendLoop(MaterializerState& state, getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToStart(loop.getBody()); - - Value index = loop.getInductionVar(); - Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + auto sendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + return success(); + }); + if (failed(sendLoop)) + return failure(); + return success(); } -Value buildProjectedPackedPayload(MaterializerState& state, - Operation* anchor, - Value fullPayload, - const ProjectedTransferDescriptor& descriptor, - Value laneIndex, - Location loc) { +FailureOr buildProjectedPackedPayload(MaterializerState& state, + Operation* anchor, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value laneIndex, + Location loc) { assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection"); Value init = tensor::EmptyOp::create( @@ -1607,42 +1638,42 @@ Value buildProjectedPackedPayload(MaterializerState& state, Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); + Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); + Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc); + Value fragment = createSingleDimExtractSlice( + state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape()); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value fragmentIndex = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); - Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - - Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc); - Value fragment = createSingleDimExtractSlice( - state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape()); - - Value packedOffset = scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); - scf::YieldOp::create(state.rewriter, loc, next); - - return loop.getResult(0); + Value packedOffset = + scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); } -void appendProjectedScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ProjectedTransferDescriptor& descriptor, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { +LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ProjectedTransferDescriptor& descriptor, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); @@ -1664,11 +1695,14 @@ void appendProjectedScalarSendLoop(MaterializerState& state, state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); } else { - sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); + auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); + if (failed(packedPayload)) + return failure(); + sendPayload = *packedPayload; } SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); - return; + return success(); } Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); @@ -1676,36 +1710,46 @@ void appendProjectedScalarSendLoop(MaterializerState& state, getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); + auto projectedSendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToStart(loop.getBody()); + Value sendPayload; + if (descriptor.fragmentsPerLane == 1) { + Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc); + sendPayload = createSingleDimExtractSlice( + state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); + } + else { + auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); + if (failed(packedPayload)) + return failure(); + sendPayload = *packedPayload; + } - Value index = loop.getInductionVar(); - Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); - - Value sendPayload; - if (descriptor.fragmentsPerLane == 1) { - Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc); - sendPayload = createSingleDimExtractSlice( - state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); - } - else { - sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); - } - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); + return success(); + }); + if (failed(projectedSendLoop)) + return failure(); + return success(); } -void appendSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { +LogicalResult appendSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); assert(!channelIds.empty() && "expected at least one send"); @@ -1717,16 +1761,16 @@ void appendSend(MaterializerState& state, Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - return; + return success(); } if (channelIds.size() == 1) { appendScalarSend( state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); - return; + return success(); } - appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + return appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); } Value appendScalarReceive(MaterializerState& state, @@ -1854,15 +1898,17 @@ SmallVector collectDestinationClassesForKeys(MaterializerState& stat return destinations; } -SmallVector emitScalarSourceSends(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - ArrayRef destinationClasses, - Value payload, - Location loc) { +FailureOr> emitScalarSourceSends(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + ArrayRef destinationClasses, + Value payload, + Location loc) { assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); - int32_t sourceCpu = static_cast(sourceClass.cpus.front()); + auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); + if (failed(sourceCpu)) + return failure(); SmallVector receivePlans; receivePlans.reserve(destinationClasses.size()); @@ -1870,7 +1916,7 @@ SmallVector emitScalarSourceSends(MaterializerState& const auto tryEmitProjected = [&](ClassId destinationClass, const SmallVector& channelIds, const SmallVector& sourceCoreIds, - const SmallVector& targetCoreIds) -> bool { + const SmallVector& targetCoreIds) -> FailureOr { if (keys.size() != 1) return false; @@ -1891,8 +1937,9 @@ SmallVector emitScalarSourceSends(MaterializerState& != targetClass.cpus.size() * static_cast(descriptor.fragmentsPerLane)) return false; - appendProjectedScalarSendLoop( - state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc); + if (failed(appendProjectedScalarSendLoop( + state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc))) + return failure(); Value received = appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); @@ -1911,24 +1958,36 @@ SmallVector emitScalarSourceSends(MaterializerState& ScalarSourceReceivePlan plan; plan.targetClass = destinationClass; - auto appendMessage = [&](int32_t targetCpu) { + auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { + auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); + if (failed(checkedTargetCpu)) + return failure(); int64_t channelId = state.nextChannelId++; plan.channelIds.push_back(channelId); - plan.sourceCoreIds.push_back(sourceCpu); - plan.targetCoreIds.push_back(targetCpu); + plan.sourceCoreIds.push_back(*sourceCpu); + plan.targetCoreIds.push_back(*checkedTargetCpu); + return success(); }; - if (!targetClass.isBatch) - appendMessage(static_cast(targetClass.cpus.front())); - else + if (!targetClass.isBatch) { + if (failed(appendMessage(targetClass.cpus.front()))) + return failure(); + } + else { for (CpuId targetCpu : targetClass.cpus) - appendMessage(static_cast(targetCpu)); + if (failed(appendMessage(targetCpu))) + return failure(); + } - if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds)) + auto emittedProjected = tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds); + if (failed(emittedProjected)) + return failure(); + if (*emittedProjected) continue; - appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); + if (failed(appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc))) + return failure(); receivePlans.push_back(std::move(plan)); } @@ -1943,10 +2002,11 @@ LogicalResult emitScalarSourceCommunication( state.availableValues.record(key, sourceClass.id, payload); SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); - SmallVector receivePlans = - emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc); + auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc); + if (failed(receivePlans)) + return failure(); - for (const ScalarSourceReceivePlan& plan : receivePlans) { + for (const ScalarSourceReceivePlan& plan : *receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; Value received = appendReceive( @@ -1987,14 +2047,20 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, sourceCoreIds.reserve(sourceClass.cpus.size()); targetCoreIds.reserve(sourceClass.cpus.size()); - int32_t targetCpu = static_cast(targetClass.cpus.front()); + auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); + if (failed(targetCpu)) + return failure(); for (CpuId sourceCpu : sourceClass.cpus) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); + if (failed(checkedSourceCpu)) + return failure(); channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(targetCpu); + sourceCoreIds.push_back(*checkedSourceCpu); + targetCoreIds.push_back(*targetCpu); } - appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc))) + return failure(); return registerLazyPackedScalarReceives( state, sourceClass, targetClass, keys, payload.getType(), channelIds, sourceCoreIds, targetCoreIds); } @@ -2011,12 +2077,19 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, targetCoreIds.reserve(targetClass.cpus.size()); for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); + if (failed(checkedTargetCpu)) + return failure(); channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); + sourceCoreIds.push_back(*checkedSourceCpu); + targetCoreIds.push_back(*checkedTargetCpu); } - appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc))) + return failure(); Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); for (ProducerKey key : keys) @@ -2230,35 +2303,30 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); + if (failed(produced) || produced->size() != 1) + return failure(); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value loopIndex = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); - if (failed(produced)) + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) return failure(); - - if (produced->size() != 1) - return failure(); - - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset); - - scf::YieldOp::create(state.rewriter, loc, next); - - run.packed = loop.getResult(0); + run.packed = loop->results.front(); return run.packed; } @@ -2297,34 +2365,30 @@ FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); + if (failed(produced) || produced->size() != 1) + return failure(); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value loopIndex = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); - if (failed(produced)) + Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc); + Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) return failure(); - - if (produced->size() != 1) - return failure(); - - Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc); - Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc); - - scf::YieldOp::create(state.rewriter, loc, next); - return loop.getResult(0); + return loop->results.front(); } FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, @@ -2362,32 +2426,30 @@ FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value index = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc); - - Value received = - SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - - Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc); - Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc); - - scf::YieldOp::create(state.rewriter, loc, next); - return loop.getResult(0); + Value received = + SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc); + Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); } FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, @@ -2444,25 +2506,24 @@ FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); - - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value slotIndex = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc); - Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc); - Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc); - - scf::YieldOp::create(state.rewriter, loc, next); - return loop.getResult(0); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc); + Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc); + Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); } LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, @@ -3055,10 +3116,11 @@ LogicalResult emitPackedRunFanout(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); - SmallVector receivePlans = - emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc); + auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc); + if (failed(receivePlans)) + return failure(); - for (const ScalarSourceReceivePlan& plan : receivePlans) { + for (const ScalarSourceReceivePlan& plan : *receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; Value received = @@ -3190,39 +3252,36 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues)); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); + FailureOr> produced = cloneBatchBodyForLane( + state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex); + if (failed(produced)) + return failure(); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value loopIndex = loop.getInductionVar(); - Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex); - if (failed(produced)) + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = iterArgs[outputIndex]; + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + } + return success(); + }); + if (failed(loop)) return failure(); - SmallVector yielded; - yielded.reserve(produced->size()); - - for (auto [outputIndex, output] : llvm::enumerate(*produced)) { - auto fragmentType = cast(output.getType()); - Value acc = body->getArgument(1 + outputIndex); - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); - yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); - } - - scf::YieldOp::create(state.rewriter, loc, yielded); - SmallVector results; - results.reserve(loop.getNumResults()); - for (Value result : loop.getResults()) + results.reserve(loop->results.size()); + for (Value result : loop->results) results.push_back(result); return results; } @@ -3523,12 +3582,18 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, for ([[maybe_unused]] const MaterializationRunSlot& slot : run) { for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = + getCheckedCoreId(targetClass.op, + targetClass.isBatch ? targetClass.cpus[lane] : targetClass.cpus.front(), + "batch run target core id"); + if (failed(checkedTargetCpu)) + return failure(); plan.channelIds.push_back(state.nextChannelId++); - plan.sourceCoreIds.push_back(static_cast(sourceCpu)); - - int32_t targetCpu = targetClass.isBatch ? static_cast(targetClass.cpus[lane]) - : static_cast(targetClass.cpus.front()); - plan.targetCoreIds.push_back(targetCpu); + plan.sourceCoreIds.push_back(*checkedSourceCpu); + plan.targetCoreIds.push_back(*checkedTargetCpu); } } @@ -3666,36 +3731,42 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToStart(loop.getBody()); + FailureOr> produced = cloneBatchBodyForLane( + state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex); + if (failed(produced)) + return failure(); - Value slotIndex = loop.getInductionVar(); - Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return failure(); - FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex); - if (failed(produced)) + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); + } + + for (const BatchRunSendPlan& plan : sendPlans) { + if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) + return failure(); + + if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) + return failure(); + } + return success(); + }); + if (failed(loop)) return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return targetClass.op->emitError("internal error: missing compacted batch run result"); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - - for (const BatchRunSendPlan& plan : sendPlans) { - if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) - return sourceBatch.emitOpError("failed to recover per-lane output type for compacted batch run"); - - if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) - return failure(); - } } return success(); @@ -3754,15 +3825,15 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns return success(); } -Value createReceiveConcatLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, - RankedTensorType concatType, - RankedTensorType fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { +FailureOr createReceiveConcatLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + RankedTensorType concatType, + RankedTensorType fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); assert(!channelIds.empty() && "expected at least one receive"); @@ -3774,31 +3845,30 @@ Value createReceiveConcatLoop(MaterializerState& state, state.rewriter.setInsertionPoint(insertionPoint); Value init = tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value index = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); - - Value received = - SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput(); - - Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset); - scf::YieldOp::create(state.rewriter, loc, next); - - return loop.getResult(0); + Value received = + SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset); + yielded.push_back(next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); } void replaceHostUses(MaterializerState& state) { @@ -3832,7 +3902,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(collectHostOutputs(state))) return failure(); - createEmptyMaterializedOps(state); + if (failed(createEmptyMaterializedOps(state))) + return failure(); if (failed(collectProducerDestinations(state))) return failure(); if (failed(collectProjectedTransfers(state))) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 4cb3682..997f6a6 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -37,6 +37,7 @@ #include "Scheduling/MergeSchedulingAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -128,8 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) { } static std::optional getComputeCoreId(SpatCompute compute) { - if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return static_cast(coreIdAttr.getInt()); + if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { + auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); + if (failed(checkedCoreId)) + return std::nullopt; + return *checkedCoreId; + } return std::nullopt; } diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index 07080c8..026edd3 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -1,11 +1,5 @@ add_pim_library(OMPimPasses MessagePass.cpp - PimCodegen/HostConstantFolding/Common.cpp - PimCodegen/HostConstantFolding/Patterns/Constant.cpp - PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp - PimCodegen/HostConstantFolding/Patterns/Subview.cpp - PimCodegen/MaterializeHostConstantsPass.cpp - PimCodegen/VerificationPass.cpp PimCodegen/EmitPimCodePass.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index 86a1e1e..fe817ba 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -15,7 +15,7 @@ std::unique_ptr createSpatialToPimPass(); std::unique_ptr createPimBufferizationPass(); -std::unique_ptr createPimStaticMemoryCoalescingPass(); +std::unique_ptr createPimMemoryCoalescingPass(); std::unique_ptr createMergeComputeNodesPass(); diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 76836a7..98ab342 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -75,7 +75,7 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPimPass); registerPass(createPimBufferizationPass); - registerPass(createPimStaticMemoryCoalescingPass); + registerPass(createPimMemoryCoalescingPass); registerPass(createMergeComputeNodesPass); registerPass(createPimHostConstantFoldingPass); registerPass(createPimMaterializeHostConstantsPass);