add shared loop creation helpers
Validate Operations / validate-operations (push) Waiting to run

add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
This commit is contained in:
NiccoloN
2026-06-01 16:49:06 +02:00
parent 356be6ccc2
commit 636310d0cb
55 changed files with 2007 additions and 1103 deletions
+4 -1
View File
@@ -121,6 +121,9 @@ add_pim_library(OMPIMAccel
OMSpatialToPim OMSpatialToPim
OMPimCommon OMPimCommon
OMPimBufferization OMPimBufferization
OMPimStaticMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification
MLIRTensorInferTypeOpInterfaceImpl MLIRTensorInferTypeOpInterfaceImpl
) )
+3
View File
@@ -5,9 +5,11 @@ add_pim_library(OMPimCommon
IR/ConstantUtils.cpp IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp IR/EntryPointUtils.cpp
IR/LoopUtils.cpp
IR/ShapeUtils.cpp IR/ShapeUtils.cpp
IR/SubviewUtils.cpp IR/SubviewUtils.cpp
IR/WeightUtils.cpp IR/WeightUtils.cpp
Support/CheckedArithmetic.cpp
Support/DebugDump.cpp Support/DebugDump.cpp
Support/Diagnostics.cpp Support/Diagnostics.cpp
Support/FileSystemUtils.cpp Support/FileSystemUtils.cpp
@@ -20,6 +22,7 @@ add_pim_library(OMPimCommon
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect MLIRLinalgDialect
MLIRSCFDialect
onnx onnx
SpatialOps SpatialOps
PimOps PimOps
+96
View File
@@ -0,0 +1,96 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
#include "ConstantUtils.hpp"
#include "LoopUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Computes the trip count of a loop whose bounds and step are all matched as
/// constant index values.
///
/// Returns std::nullopt when any operand is not a matched constant or when the
/// step is not strictly positive; returns 0 when the upper bound does not
/// exceed the lower bound; otherwise returns ceil((upper - lower) / step).
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
  auto lowerConst = matchConstantIndexValue(lowerBound);
  auto upperConst = matchConstantIndexValue(upperBound);
  auto stepConst = matchConstantIndexValue(step);

  // All three operands must be known, and only forward-stepping loops are counted.
  if (!lowerConst || !upperConst || !stepConst || *stepConst <= 0)
    return std::nullopt;

  // An empty iteration space never executes the body.
  if (*lowerConst >= *upperConst)
    return int64_t {0};

  return llvm::divideCeil(*upperConst - *lowerConst, *stepConst);
}
} // namespace
/// Checks that the body callback produced exactly one yielded value per
/// iteration argument.
///
/// Emits an error at `loc` and returns failure when the counts differ.
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
  if (yieldedValues.size() != initArgs.size()) {
    emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size()
                   << " iter args";
    return failure();
  }
  return success();
}
/// Builds a loop over [lowerBound, upperBound) with the given step, skipping
/// the scf.for when the trip count is statically known to be 0 or 1.
///
/// - A constant non-positive step is rejected with an error at `loc`.
/// - Static trip count 0: nothing is emitted; the results are `initArgs`.
/// - Static trip count 1: the body is emitted inline at the current insertion
///   point, with `lowerBound` as the induction variable.
/// - Otherwise: an scf.for is created, the body is emitted inside it, and the
///   results are the loop results. The builder is left positioned after the loop.
///
/// Returns failure when the body callback fails or yields a number of values
/// different from `initArgs.size()`. A loop created before such a failure is
/// erased again.
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
                                                      Location loc,
                                                      Value lowerBound,
                                                      Value upperBound,
                                                      Value step,
                                                      ValueRange initArgs,
                                                      NormalizedLoopBodyBuilder bodyBuilder) {
  NormalizedLoopResult normalized;

  // Reject a statically known zero or negative step up front.
  auto constantStep = matchConstantIndexValue(step);
  if (constantStep && *constantStep <= 0) {
    emitError(loc) << "normalized scf.for requires a positive step, got " << *constantStep;
    return failure();
  }

  auto staticTrips = getStaticTripCount(lowerBound, upperBound, step);

  // Zero iterations: the iteration arguments pass through unchanged.
  if (staticTrips && *staticTrips == 0) {
    llvm::append_range(normalized.results, initArgs);
    return normalized;
  }

  // Exactly one iteration: emit the body inline, no loop op needed.
  if (staticTrips && *staticTrips == 1) {
    normalized.inductionVar = lowerBound;
    if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, normalized.results)))
      return failure();
    if (failed(validateNormalizedLoopYields(loc, initArgs, normalized.results)))
      return failure();
    return normalized;
  }

  // General case: materialize the scf.for and populate its body.
  scf::ForOp forOp = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
  normalized.loop = forOp;
  normalized.inductionVar = forOp.getInductionVar();
  {
    OpBuilder::InsertionGuard restoreInsertionPoint(builder);
    Block* loopBody = forOp.getBody();

    // Drop a yield that is already present so the one created below becomes
    // the terminator.
    if (!loopBody->empty())
      if (auto existingYield = dyn_cast<scf::YieldOp>(loopBody->back()))
        existingYield->erase();

    builder.setInsertionPointToEnd(loopBody);
    ValueRange iterArgs = forOp.getRegionIterArgs();

    // Short-circuit keeps the original order: the body runs first, validation second.
    const bool bodyOk = succeeded(bodyBuilder(builder, loc, normalized.inductionVar, iterArgs, normalized.results))
                     && succeeded(validateNormalizedLoopYields(loc, initArgs, normalized.results));
    if (!bodyOk) {
      forOp.erase();
      return failure();
    }
    scf::YieldOp::create(builder, loc, normalized.results);
  }

  builder.setInsertionPointAfter(forOp);
  // Callers consume the loop results, not the values yielded inside the body.
  normalized.results.assign(forOp.getResults().begin(), forOp.getResults().end());
  return normalized;
}
} // namespace onnx_mlir
+30
View File
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
/// Outcome of buildNormalizedScfFor.
struct NormalizedLoopResult {
/// Induction value handed to the body: the loop's induction variable when an
/// scf.for was created, the lower bound when a single-iteration body was
/// emitted inline, and left null when the body was never emitted.
mlir::Value inductionVar;
/// Values produced by the construct: the scf.for results, the values yielded
/// by an inlined body, or the initial iteration arguments for a zero-trip loop.
llvm::SmallVector<mlir::Value, 4> results;
/// The created scf.for; null when no loop operation was created.
mlir::scf::ForOp loop;
/// True when no scf.for operation was created.
bool wasInlined() const { return !loop; }
};
/// Callback that emits a loop body. buildNormalizedScfFor invokes it with, in
/// order: the builder, the location, the induction value, the current
/// iteration arguments, and an output vector the callback fills with the
/// values to yield. Returning failure aborts the loop construction.
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
/// Builds a loop over [lowerBound, upperBound) with the given step and
/// iteration arguments. When the bounds and step are constant and give a trip
/// count of 0 or 1, no scf.for is created: a zero-trip loop returns `initArgs`
/// as its results and a single-trip loop has its body emitted inline.
/// Fails when the step is a constant that is not strictly positive, when
/// `bodyBuilder` fails, or when `bodyBuilder` yields a number of values
/// different from `initArgs.size()`.
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::Value lowerBound,
mlir::Value upperBound,
mlir::Value step,
mlir::ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder);
} // namespace onnx_mlir
@@ -0,0 +1,222 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
/// Writes the line "PIM <fieldName> <message>" to llvm::errs().
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
  llvm::raw_ostream& err = llvm::errs();
  err << "PIM " << fieldName;
  err << " " << message;
  err << "\n";
}
/// Converts `value` to the integral type `To`, emitting an error at `loc` and
/// returning failure when the value is not representable in `To`.
///
/// Each signedness combination is compared in a type wide enough for both
/// sides, so widening conversions (e.g. int32_t -> int64_t) succeed and no
/// limit of `To` is narrowed into `From` before the comparison.
template <typename To, typename From>
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");

  using ToLimits = std::numeric_limits<To>;
  if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
    // Same signedness: the common type is the wider of the two, so neither the
    // value nor the limits of `To` are truncated by the conversion.
    using Wide = std::common_type_t<From, To>;
    const Wide wideValue = static_cast<Wide>(value);
    if (wideValue < static_cast<Wide>(ToLimits::min()) || wideValue > static_cast<Wide>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  else if constexpr (std::is_signed_v<From>) {
    // Signed -> unsigned: reject negatives, then compare magnitudes as unsigned.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::make_unsigned_t<To>;
    if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  else {
    // Unsigned -> signed: only the upper bound can be violated.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
    if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  return static_cast<To>(value);
}
/// Multiplies two unsigned integers, emitting a "multiplication overflow"
/// error at `loc` and returning failure when the product does not fit in `UInt`.
template <typename UInt>
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
                "checkedMulAtLocation requires unsigned integral types");

  // A zero factor cannot overflow; otherwise the product fits iff rhs <= max / lhs.
  const bool fits = lhs == 0 || rhs <= std::numeric_limits<UInt>::max() / lhs;
  if (fits)
    return lhs * rhs;

  emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
  return failure();
}
} // namespace
/// Starts an op error on `anchor` whose text is "<fieldName> <message>".
/// `anchor` must be non-null.
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
  assert(anchor && "expected arithmetic diagnostics to have an anchor op");
  InFlightDiagnostic diagnostic = anchor->emitOpError();
  diagnostic << fieldName << " " << message;
  return diagnostic;
}
/// Starts an error at `loc` whose text is "PIM <fieldName> <message>".
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
  InFlightDiagnostic diagnostic = emitError(loc);
  diagnostic << "PIM " << fieldName << " " << message;
  return diagnostic;
}
/// Narrows a signed 64-bit value to int32_t; reports on `anchor` and fails
/// when the value is out of range.
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCast<int32_t, int64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Narrows an unsigned 64-bit value to int32_t; reports on `anchor` and fails
/// when the value is out of range.
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCast<int32_t, uint64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Narrows an unsigned 64-bit value to uint8_t; reports on `anchor` and fails
/// when the value is out of range.
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<uint8_t> narrowed = checkedCast<uint8_t, uint64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Converts a signed 64-bit value to size_t; reports on `anchor` and fails
/// when the value is not representable.
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<size_t> converted = checkedCast<size_t, int64_t>(value, anchor, fieldName);
  return converted;
}
/// Builds an i32 integer attribute from a signed 64-bit value, reporting on
/// `anchor` and failing when the value does not fit in int32_t.
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
  auto narrowed = checkedI32(value, anchor, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from an unsigned 64-bit value, reporting on
/// `anchor` and failing when the value does not fit in int32_t.
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
  auto narrowed = checkedI32(value, anchor, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from a signed 64-bit value, reporting at
/// `loc` and failing when the value does not fit in int32_t.
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
  auto narrowed = checkedCastAtLocation<int32_t>(value, loc, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from an unsigned 64-bit value, reporting at
/// `loc` and failing when the value does not fit in int32_t.
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
  auto narrowed = checkedCastAtLocation<int32_t>(value, loc, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Computes the total byte size of a shaped type with overflow checking.
///
/// Reports on `anchor` and fails when the type is not statically shaped, when
/// its element type is not byte-sized, when a dimension is negative, or when
/// the element count or the byte size overflows uint64_t.
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");

  // Shared rejection path so every precondition reports through the anchor op.
  auto reject = [&](llvm::StringRef reason) -> FailureOr<uint64_t> {
    emitCheckedArithmeticError(anchor, fieldName, reason);
    return failure();
  };

  if (!type.hasStaticShape())
    return reject("requires static shaped type");
  if (!hasByteSizedElementType(type.getElementType()))
    return reject("requires byte-sized element type");

  // Accumulate the element count one dimension at a time, checking each product.
  uint64_t elementCount = 1;
  for (int64_t extent : type.getShape()) {
    if (extent < 0)
      return reject("requires nonnegative dimensions");
    auto product = checkedMul(elementCount, static_cast<uint64_t>(extent), anchor, fieldName);
    if (failed(product))
      return failure();
    elementCount = *product;
  }

  const auto bytesPerElement = static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType()));
  return checkedMul(elementCount, bytesPerElement, anchor, fieldName);
}
/// Computes the total byte size of a shaped type with overflow checking.
///
/// Reports at `loc` and fails when the type is not statically shaped, when its
/// element type is not byte-sized, when a dimension is negative, or when the
/// element count or the byte size overflows uint64_t.
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
  // Shared rejection path so every precondition reports at the same location.
  auto reject = [&](llvm::StringRef reason) -> FailureOr<uint64_t> {
    emitCheckedArithmeticError(loc, fieldName, reason);
    return failure();
  };

  if (!type.hasStaticShape())
    return reject("requires static shaped type");
  if (!hasByteSizedElementType(type.getElementType()))
    return reject("requires byte-sized element type");

  // Accumulate the element count one dimension at a time, checking each product.
  uint64_t elementCount = 1;
  for (int64_t extent : type.getShape()) {
    if (extent < 0)
      return reject("requires nonnegative dimensions");
    auto product = checkedMulAtLocation(elementCount, static_cast<uint64_t>(extent), loc, fieldName);
    if (failed(product))
      return failure();
    elementCount = *product;
  }

  const auto bytesPerElement = static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType()));
  return checkedMulAtLocation(elementCount, bytesPerElement, loc, fieldName);
}
/// Narrows a signed 64-bit value to int32_t, terminating the process when the
/// value is out of range.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
  constexpr int64_t lowest = std::numeric_limits<int32_t>::min();
  constexpr int64_t highest = std::numeric_limits<int32_t>::max();
  if (value < lowest || value > highest) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<int32_t>(value);
}
/// Narrows an unsigned 64-bit value to int32_t, terminating the process when
/// the value exceeds the int32_t maximum.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
  constexpr uint64_t highest = static_cast<uint64_t>(std::numeric_limits<int32_t>::max());
  if (value > highest) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<int32_t>(value);
}
/// Narrows an unsigned 64-bit value to uint8_t, terminating the process when
/// the value exceeds the uint8_t maximum.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
  constexpr uint64_t highest = static_cast<uint64_t>(std::numeric_limits<uint8_t>::max());
  if (value > highest) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<uint8_t>(value);
}
/// Converts a signed 64-bit value to size_t, terminating the process when the
/// value is negative or does not fit in size_t.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
  // The upper bound only matters where size_t is narrower than 64 bits; the
  // guard keeps the comparison from being trivially false elsewhere.
  bool tooLarge = false;
  if constexpr (sizeof(size_t) < sizeof(int64_t))
    tooLarge = value > static_cast<int64_t>(std::numeric_limits<size_t>::max());

  if (value < 0 || tooLarge) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<size_t>(value);
}
/// Adds two size_t values, terminating the process when the sum overflows.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
  // The sum fits iff rhs does not exceed the headroom left above lhs.
  const size_t headroom = std::numeric_limits<size_t>::max() - lhs;
  if (rhs > headroom) {
    emitCrashMessage(fieldName, "addition overflow");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return lhs + rhs;
}
/// Multiplies two size_t values, terminating the process when the product
/// overflows.
///
/// The failure is data-dependent, so it goes through llvm::report_fatal_error
/// rather than llvm_unreachable: the latter is only an optimizer hint in
/// builds without assertions and would let the check be removed.
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
  // A zero factor cannot overflow; otherwise the product fits iff rhs <= max / lhs.
  const bool overflows = lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs;
  if (overflows) {
    emitCrashMessage(fieldName, "multiplication overflow");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return lhs * rhs;
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,107 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace onnx_mlir::pim {
/// Starts an op error on `anchor` describing a checked-arithmetic failure for
/// the quantity named by `fieldName`. `anchor` is expected to be non-null.
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
/// Location-based variant for callers that have no anchor operation.
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
/// Converts `value` to the integral type `To`, emitting an error on `anchor`
/// and returning failure when the value is not representable in `To`.
///
/// Each signedness combination is compared in a type wide enough for both
/// sides, so widening conversions (e.g. int32_t -> int64_t) succeed and no
/// limit of `To` is narrowed into `From` before the comparison.
template <typename To, typename From>
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");

  using ToLimits = std::numeric_limits<To>;
  if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
    // Same signedness: the common type is the wider of the two, so neither the
    // value nor the limits of `To` are truncated by the conversion.
    using Wide = std::common_type_t<From, To>;
    const Wide wideValue = static_cast<Wide>(value);
    if (wideValue < static_cast<Wide>(ToLimits::min()) || wideValue > static_cast<Wide>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  else if constexpr (std::is_signed_v<From>) {
    // Signed -> unsigned: reject negatives, then compare magnitudes as unsigned.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::make_unsigned_t<To>;
    if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  else {
    // Unsigned -> signed: only the upper bound can be violated.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
    if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  return static_cast<To>(value);
}
/// Adds two unsigned integers, emitting an "addition overflow" error on
/// `anchor` and returning failure when the sum does not fit in `UInt`.
template <typename UInt>
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");

  // The sum fits iff rhs does not exceed the headroom left above lhs.
  const bool fits = rhs <= std::numeric_limits<UInt>::max() - lhs;
  if (fits)
    return lhs + rhs;

  emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
  return mlir::failure();
}
/// Multiplies two unsigned integers, emitting a "multiplication overflow"
/// error on `anchor` and returning failure when the product does not fit in
/// `UInt`.
template <typename UInt>
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");

  // A zero factor cannot overflow; otherwise the product fits iff rhs <= max / lhs.
  const bool fits = lhs == 0 || rhs <= std::numeric_limits<UInt>::max() / lhs;
  if (fits)
    return lhs * rhs;

  emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
  return mlir::failure();
}
/// Fixed-type convenience wrappers around checkedCast. Each reports on
/// `anchor` and returns failure when `value` is not representable in the
/// target type named by the function.
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
/// Builds an i32 integer attribute from `value`, failing with a diagnostic
/// when the value does not fit in int32_t. The first two overloads report on
/// an anchor operation (expected to be non-null); the last two report at a
/// location.
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
/// Computes the byte size of a shaped type with overflow checking. Fails with
/// a diagnostic when the type is not statically shaped, when its element type
/// is not byte-sized, or when the element count or byte size overflows
/// uint64_t. One overload reports on an anchor operation (expected to be
/// non-null), the other at a location.
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
/// Variants for call sites that have neither an anchor operation nor a
/// location and cannot propagate a failure. When the conversion or arithmetic
/// would be out of range, they print a "PIM <fieldName> <reason>" line to
/// llvm::errs() and take a fatal path instead of returning a value.
/// NOTE(review): the exact termination mechanism is an implementation detail
/// of the .cpp file — confirm it also triggers in builds without assertions.
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
} // namespace onnx_mlir::pim
+4 -1
View File
@@ -28,7 +28,10 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions OMPimCompilerOptions
OMPimCommon OMPimCommon
OMPimBufferization OMPimBufferization
OMPimStaticMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification
OMPimPasses OMPimPasses
OMONNXToSpatial OMONNXToSpatial
OMSpatialToPim OMSpatialToPim
+4 -9
View File
@@ -6,8 +6,8 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <array> #include <array>
#include <cassert>
#include <limits> #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
namespace onnx_mlir::pim_binary { namespace onnx_mlir::pim_binary {
@@ -95,15 +95,10 @@ inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecor
writeInt32LE(os, record.generic3); writeInt32LE(os, record.generic3);
} }
inline int32_t toI32(int64_t value) { inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); }
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
&& "PIM binary field out of int32 range");
return static_cast<int32_t>(value);
}
inline uint8_t toU8(int64_t value) { inline uint8_t toU8(int64_t value) {
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range"); return onnx_mlir::pim::checkedU8OrCrash(static_cast<uint64_t>(value), "binary field");
return static_cast<uint8_t>(value);
} }
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) { inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
+44 -28
View File
@@ -25,12 +25,14 @@
#include <cassert> #include <cassert>
#include <cstdint> #include <cstdint>
#include <fstream> #include <fstream>
#include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include "Common/IR/CompactAsmUtils.hpp" #include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp" #include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
@@ -71,12 +73,23 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
return {value, getLaneForMemoryValue(value, lane)}; return {value, getLaneForMemoryValue(value, lane)};
} }
/// Returns the byte size of `type` as an int32_t, terminating the process when
/// the size cannot be computed or does not fit in int32_t.
///
/// The size is computed with an unknown location because no operation or
/// location is passed in. The failure is data-dependent, so it goes through
/// llvm::report_fatal_error rather than llvm_unreachable, which is only an
/// optimizer hint in builds without assertions.
static int32_t getVectorByteSizeOrCrash(ShapedType type) {
  auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size");
  if (failed(byteSize))
    llvm::report_fatal_error("Failed to compute checked vector byte size");
  return pim::checkedI32OrCrash(*byteSize, "vector byte size");
}
} // namespace } // namespace
MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = getShapedTypeSizeInBytes(type); auto checkedAllocSize =
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
if (failed(checkedAllocSize))
llvm_unreachable("Failed to compute checked allocation byte size");
size_t allocSize = static_cast<size_t>(*checkedAllocSize);
MemEntry memEntry = {0, allocSize}; MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first; return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first;
} }
@@ -272,7 +285,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
llvm_unreachable("Missing mem entry"); llvm_unreachable("Missing mem entry");
} }
return iter->second.address + resolvedAddress->byteOffset; size_t byteOffset = pim::checkedSizeOrCrash(resolvedAddress->byteOffset, "resolved PIM byte offset");
return pim::checkedAddOrCrash(iter->second.address, byteOffset, "resolved PIM address");
} }
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value, llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
@@ -291,8 +305,12 @@ llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); } void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) { void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back( reportEntries.push_back({MemoryReportEntry::Kind::Core,
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca}); coreId,
{pim::checkedI32OrCrash(coreId, "memory report core id")},
row,
row.numAlloca,
row.sizeAlloca});
} }
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
@@ -402,24 +420,24 @@ void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t i
pim_binary::InstructionRecord instruction; pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi; instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber); instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = static_cast<int32_t>(immediate); instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
emitInstruction(instruction); emitInstruction(instruction);
} }
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const { void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
} }
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const { void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
} }
void PimCodeGen::setupRdRs1Rs2( void PimCodeGen::setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const { size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset); genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset); genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); genSetRegisterImmediateUnsigned(2, pim::checkedAddOrCrash(rs2Address, rs2Offset, "rs2 address"));
} }
void PimCodeGen::emitMemCopyOp(StringRef opName, void PimCodeGen::emitMemCopyOp(StringRef opName,
@@ -437,8 +455,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic1 = 0; instruction.generic1 = 0;
instruction.generic2 = 0; instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size); instruction.generic3 = pim::checkedI32OrCrash(size, sizeFieldName);
(void) sizeFieldName;
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -448,10 +465,10 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
pim_binary::InstructionRecord instruction; pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::opcodeFromString(opName); instruction.opcode = pim_binary::opcodeFromString(opName);
instruction.rd = 0; instruction.rd = 0;
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId)); instruction.r2OrImm = pim::checkedI32OrCrash(remapCoreId(coreId), "communication core id");
instruction.generic1 = 0; instruction.generic1 = 0;
instruction.generic2 = 0; instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size); instruction.generic3 = pim::checkedI32OrCrash(size, "communication byte size");
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -464,7 +481,7 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 8; instruction.r2OrImm = 8;
instruction.generic1 = 0; instruction.generic1 = 0;
instruction.generic2 = static_cast<int32_t>(groupId); instruction.generic2 = pim::checkedI32OrCrash(groupId, "mvm group id");
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -578,7 +595,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvaddOp.getLhs().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -593,7 +610,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvsubOp.getLhs().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -608,7 +625,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmulOp.getLhs().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -623,7 +640,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmaxOp.getLhs().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -638,7 +655,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvdmulOp.getLhs().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -653,7 +670,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 1; instruction.r2OrImm = 1;
instruction.generic1 = 1; instruction.generic1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vavgOp.getInput().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -666,7 +683,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vrelu; instruction.opcode = pim_binary::Opcode::vrelu;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vreluOp.getInput().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -679,7 +696,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vtanh; instruction.opcode = pim_binary::Opcode::vtanh;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vtanhOp.getInput().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -692,7 +709,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vsigm; instruction.opcode = pim_binary::Opcode::vsigm;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType()))); instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsigmOp.getInput().getType()));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -705,8 +722,7 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
instruction.opcode = pim_binary::Opcode::vsoftmax; instruction.opcode = pim_binary::Opcode::vsoftmax;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsoftmaxOp.getInput().getType()));
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -1370,7 +1386,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup)) if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
return err; return err;
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId)); reportedCoreIds.push_back(pim::checkedI32OrCrash(job.emittedCoreId, "batch report core id"));
if (!batchPerCoreRow) if (!batchPerCoreRow)
batchPerCoreRow = result.reportRow; batchPerCoreRow = result.reportRow;
batchRow = addMemoryReportRows(batchRow, result.reportRow); batchRow = addMemoryReportRows(batchRow, result.reportRow);
+1 -1
View File
@@ -40,7 +40,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimBufferized) { if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass()); pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
pm.addPass(createMessagePass("Pim bufferized")); pm.addPass(createMessagePass("Pim bufferized"));
} }
@@ -48,6 +47,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
pm.addPass(createPimHostConstantFoldingPass()); pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded")); pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass()); pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass()); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimCodePass()); pm.addPass(createEmitPimCodePass());
@@ -12,6 +12,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -180,8 +181,11 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max()) if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure()); return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create( auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs); if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()}; mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc}; mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -8,6 +8,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -305,20 +306,24 @@ static Value createIm2colRowComputes(Value x,
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); auto im2colLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(im2colLoop.getBody()); rewriter,
loc,
c0,
cNumPatches,
c1,
ValueRange {im2colInit},
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value im2colAcc = iterArgs.front();
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
Value patchIndex = im2colLoop.getInductionVar(); SmallVector<OpFoldResult> offsets = {
Value im2colAcc = im2colLoop.getRegionIterArgs().front(); batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn), rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight), rewriter.getIndexAttr(wHeight),
@@ -328,10 +333,11 @@ static Value createIm2colRowComputes(Value x,
rewriter.getIndexAttr(dilationHeight), rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)}; rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); Value patch =
tensor::ExtractSliceOp::create(rewriter, nestedLoc, patchType, paddedInput, offsets, sizes, strides);
Value row = tensor::CollapseShapeOp::create(rewriter, Value row = tensor::CollapseShapeOp::create(rewriter,
loc, nestedLoc,
im2colRowType, im2colRowType,
patch, patch,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
@@ -343,20 +349,24 @@ static Value createIm2colRowComputes(Value x,
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col = Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); tensor::InsertSliceOp::create(rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col); yielded.push_back(updatedIm2col);
return success();
rewriter.setInsertionPointAfter(im2colLoop); });
Value im2col = im2colLoop.getResult(0); if (failed(im2colLoop))
return failure();
Value im2col = im2colLoop->results.front();
Value gemmInputRows = im2col; Value gemmInputRows = im2col;
if (packFactor != 1) if (packFactor != 1)
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc); gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
return success();
}); });
return im2colComputeOp.getResult(0); assert(succeeded(im2colComputeOp) && "Conv im2col compute construction must succeed");
return im2colComputeOp->getResult(0);
} }
static Value createCollectedConvOutput(ValueRange gemmRows, static Value createCollectedConvOutput(ValueRange gemmRows,
@@ -15,6 +15,7 @@
#include "Common/IR/ConstantUtils.hpp" #include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -247,7 +248,7 @@ static Value createPaddedInputCompute(Value input,
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static spatial::SpatComputeBatch createVmmBatch(Value a, static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
Value b, Value b,
RankedTensorType aType, RankedTensorType aType,
RankedTensorType paddedBType, RankedTensorType paddedBType,
@@ -294,7 +295,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
createParallelInsertSliceIntoBatchOutput( createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides); rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
}); });
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed"); if (failed(batchOp))
return failure();
return *batchOp; return *batchOp;
} }
@@ -416,7 +418,7 @@ static Value createBroadcastedBiasScalar(Value bias,
return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult(); return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult();
} }
static spatial::SpatComputeBatch createVvdmulBatch(Value a, static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
Value b, Value b,
RankedTensorType aType, RankedTensorType aType,
RankedTensorType bType, RankedTensorType bType,
@@ -454,11 +456,12 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
createParallelInsertSliceIntoBatchOutput( createParallelInsertSliceIntoBatchOutput(
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides); rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
}); });
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed"); if (failed(batchOp))
return failure();
return *batchOp; return *batchOp;
} }
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, static FailureOr<spatial::SpatCompute> createDynamicGemmOutputCompute(Value scalarPieces,
Value bias, Value bias,
RankedTensorType scalarPiecesType, RankedTensorType scalarPiecesType,
RankedTensorType biasType, RankedTensorType biasType,
@@ -473,7 +476,7 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
if (bias) if (bias)
inputs.push_back(bias); inputs.push_back(bias);
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
Value pieces = blockArgs[0]; Value pieces = blockArgs[0];
Value biasArg = bias ? blockArgs[1] : Value(); Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -481,40 +484,50 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); auto loop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(loop.getBody()); rewriter,
loc,
Value lane = loop.getInductionVar(); c0,
Value outputAcc = loop.getRegionIterArgs().front(); cLaneCount,
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); c1,
ValueRange {outputInit},
[&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value outputAcc = iterArgs.front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, nestedLoc);
Value column = Value column =
onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp()); onnx_mlir::affineModConst(rewriter, nestedLoc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = Value scalar = tensor::ExtractSliceOp::create(
tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides) rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
.getResult(); .getResult();
if (alpha != 1.0f) { if (alpha != 1.0f) {
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc); Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, nestedLoc);
scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult(); scalar = spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, scalar, alphaTensor).getResult();
} }
if (biasArg) { if (biasArg) {
Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc); Value biasScalar =
createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, nestedLoc);
if (beta != 1.0f) { if (beta != 1.0f) {
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc); Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, nestedLoc);
biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult(); biasScalar =
spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, biasScalar, betaTensor).getResult();
} }
scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult(); scalar = spatial::SpatVAddOp::create(rewriter, nestedLoc, scalarType, scalar, biasScalar).getResult();
} }
SmallVector<OpFoldResult> outputOffsets {row, column}; SmallVector<OpFoldResult> outputOffsets {row, column};
Value outputNext = Value outputNext =
tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides) tensor::InsertSliceOp::create(rewriter, nestedLoc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
.getResult(); .getResult();
scf::YieldOp::create(rewriter, loc, outputNext); yielded.push_back(outputNext);
return success();
});
if (failed(loop))
return failure();
rewriter.setInsertionPointAfter(loop); spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0)); return success();
}); });
} }
@@ -579,7 +592,7 @@ static Value reducePartialPiecesForHSlice(Value partialPiecesArg,
return activePieces.front(); return activePieces.front();
} }
static spatial::SpatCompute createReductionCompute(Value partialPieces, static FailureOr<spatial::SpatCompute> createReductionCompute(Value partialPieces,
Value bias, Value bias,
RankedTensorType partialPiecesType, RankedTensorType partialPiecesType,
RankedTensorType outType, RankedTensorType outType,
@@ -591,7 +604,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
if (bias) if (bias)
inputs.push_back(bias); inputs.push_back(bias);
auto computeOp = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { auto computeOp =
createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
Value partialPiecesArg = blockArgs[0]; Value partialPiecesArg = blockArgs[0];
Value biasArg = bias ? blockArgs[1] : Value(); Value biasArg = bias ? blockArgs[1] : Value();
if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType) if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
@@ -636,15 +650,20 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = Value cOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); auto hLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(hLoop.getBody()); rewriter,
loc,
Value hSlice = hLoop.getInductionVar(); c0,
Value outputAcc = hLoop.getRegionIterArgs().front(); cOutHSlices,
scf::YieldOp::create(rewriter, loc, buildOutputSlice(outputAcc, hSlice)); c1,
ValueRange {outputInit},
rewriter.setInsertionPointAfter(hLoop); [&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
paddedOutput = hLoop.getResult(0); yielded.push_back(buildOutputSlice(iterArgs.front(), hSlice));
return success();
});
if (failed(hLoop))
return failure();
paddedOutput = hLoop->results.front();
} }
Value result = paddedOutput; Value result = paddedOutput;
@@ -657,6 +676,7 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
.getResult(); .getResult();
} }
spatial::SpatYieldOp::create(rewriter, loc, result); spatial::SpatYieldOp::create(rewriter, loc, result);
return success();
}); });
return computeOp; return computeOp;
@@ -755,9 +775,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType()); auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp = auto batchOp =
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc); createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
if (failed(batchOp))
return failure();
auto outputCompute = createDynamicGemmOutputCompute( auto outputCompute = createDynamicGemmOutputCompute(
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc); batchOp->getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
rewriter.replaceOp(gemmOp, outputCompute.getResults()); if (failed(outputCompute))
return failure();
rewriter.replaceOp(gemmOp, outputCompute->getResults());
return success(); return success();
} }
@@ -832,10 +856,14 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType()); RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
auto batchOp = auto batchOp =
createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc); createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
if (failed(batchOp))
return failure();
auto reductionCompute = createReductionCompute( auto reductionCompute = createReductionCompute(
batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc); batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
if (failed(reductionCompute))
return failure();
rewriter.replaceOp(gemmOp, reductionCompute.getResults()); rewriter.replaceOp(gemmOp, reductionCompute->getResults());
return success(); return success();
} }
@@ -8,6 +8,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -281,7 +282,7 @@ static Value getBatchLaneIndex(
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp()); rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
} }
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
Value b, Value b,
RankedTensorType aType, RankedTensorType aType,
int64_t aBatchCount, int64_t aBatchCount,
@@ -331,7 +332,8 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
createParallelInsertSliceIntoBatchOutput( createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2)); rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
}); });
assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed"); if (failed(batchOp))
return failure();
return *batchOp; return *batchOp;
} }
@@ -422,7 +424,7 @@ static Value extractDynamicBatchedRowVector(Value matrix,
}); });
} }
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a, static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
int64_t aBatchCount, int64_t aBatchCount,
Value b, Value b,
int64_t bBatchCount, int64_t bBatchCount,
@@ -466,11 +468,12 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
createParallelInsertSliceIntoBatchOutput( createParallelInsertSliceIntoBatchOutput(
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2)); rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
}); });
assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed"); if (failed(batchOp))
return failure();
return *batchOp; return *batchOp;
} }
static Value createBatchedDynamicOutputCompute(Value scalarPieces, static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
RankedTensorType scalarPiecesType, RankedTensorType scalarPiecesType,
RankedTensorType outType, RankedTensorType outType,
PatternRewriter& rewriter, PatternRewriter& rewriter,
@@ -481,29 +484,33 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces,
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType()); auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
auto computeOp = auto computeOp = createSpatCompute<1>(
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) { rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) -> LogicalResult {
Value outputInit = Value outputInit =
tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); auto loop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(loop.getBody()); rewriter,
loc,
Value lane = loop.getInductionVar(); c0,
Value outputAcc = loop.getRegionIterArgs().front(); cLaneCount,
c1,
ValueRange {outputInit},
[&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value outputAcc = iterArgs.front();
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); Value batch = affineFloorDivConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); Value batchLane = affineModConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp); Value row = affineFloorDivConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp); Value column = affineModConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create( Value scalar = tensor::ExtractSliceOp::create(
rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2)); rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
Value expanded = tensor::ExpandShapeOp::create(rewriter, Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc, nestedLoc,
outputScalarType, outputScalarType,
scalar, scalar,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
@@ -511,19 +518,23 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces,
{1, 2} {1, 2}
}); });
SmallVector<OpFoldResult> outputOffsets {batch, row, column}; SmallVector<OpFoldResult> outputOffsets {batch, row, column};
SmallVector<OpFoldResult> outputSizes { SmallVector<OpFoldResult> outputSizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
scf::YieldOp::create( Value next =
rewriter,
loc,
tensor::InsertSliceOp::create( tensor::InsertSliceOp::create(
rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) rewriter, nestedLoc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
.getResult()); .getResult();
yielded.push_back(next);
rewriter.setInsertionPointAfter(loop); return success();
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
}); });
return computeOp.getResult(0); if (failed(loop))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
return success();
});
if (failed(computeOp))
return failure();
return computeOp->getResult(0);
} }
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) { static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
@@ -587,7 +598,7 @@ static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
return activePieces.front(); return activePieces.front();
} }
static Value createBatchedReductionCompute(Value partialPieces, static FailureOr<Value> createBatchedReductionCompute(Value partialPieces,
RankedTensorType partialPiecesType, RankedTensorType partialPiecesType,
RankedTensorType outType, RankedTensorType outType,
RankedTensorType paddedOutType, RankedTensorType paddedOutType,
@@ -596,7 +607,7 @@ static Value createBatchedReductionCompute(Value partialPieces,
PatternRewriter& rewriter, PatternRewriter& rewriter,
Location loc) { Location loc) {
auto computeOp = createSpatCompute<1>( auto computeOp = createSpatCompute<1>(
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) { rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) -> LogicalResult {
const int64_t numOutRows = outType.getDimSize(1); const int64_t numOutRows = outType.getDimSize(1);
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue()); const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())}, auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
@@ -612,43 +623,55 @@ static Value createBatchedReductionCompute(Value partialPieces,
Value cNumOutHSlices = Value cNumOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit}); auto batchLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(batchLoop.getBody()); rewriter,
Value batch = batchLoop.getInductionVar();
Value batchAcc = batchLoop.getRegionIterArgs().front();
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc});
rewriter.setInsertionPointToStart(hLoop.getBody());
Value hSlice = hLoop.getInductionVar();
Value outputAcc = hLoop.getRegionIterArgs().front();
Value reduced = reduceBatchedPartialPiecesForHSlice(
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
c0,
cNumBatches,
c1,
ValueRange {outputInit},
[&](
OpBuilder&, Location batchLoc, Value batch, ValueRange batchIterArgs, SmallVectorImpl<Value>& batchYielded) {
auto hLoop = buildNormalizedScfFor(
rewriter,
batchLoc,
c0,
cNumOutHSlices,
c1,
ValueRange {batchIterArgs.front()},
[&](OpBuilder&, Location hLoc, Value hSlice, ValueRange hIterArgs, SmallVectorImpl<Value>& hYielded) {
Value outputAcc = hIterArgs.front();
Value reduced = reduceBatchedPartialPiecesForHSlice(
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, hLoc);
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
hLoc,
outputSliceType, outputSliceType,
reduced, reduced,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0, 1}, {0, 1},
{2} {2}
}); });
Value hOffset = Value hOffset = affineMulConst(
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); rewriter, hLoc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
SmallVector<OpFoldResult> outputSizes { SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; rewriter.getIndexAttr(numOutRows),
scf::YieldOp::create( rewriter.getIndexAttr(crossbarSize.getValue())};
rewriter, Value next =
loc,
tensor::InsertSliceOp::create( tensor::InsertSliceOp::create(
rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) rewriter, hLoc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
.getResult()); .getResult();
hYielded.push_back(next);
rewriter.setInsertionPointAfter(hLoop); return success();
scf::YieldOp::create(rewriter, loc, hLoop.getResult(0)); });
if (failed(hLoop))
rewriter.setInsertionPointAfter(batchLoop); return failure();
Value paddedOutput = batchLoop.getResult(0); batchYielded.push_back(hLoop->results.front());
return success();
});
if (failed(batchLoop))
return failure();
Value paddedOutput = batchLoop->results.front();
Value result = paddedOutput; Value result = paddedOutput;
if (paddedOutType != outType) { if (paddedOutType != outType) {
SmallVector<OpFoldResult> outputOffsets { SmallVector<OpFoldResult> outputOffsets {
@@ -660,8 +683,11 @@ static Value createBatchedReductionCompute(Value partialPieces,
rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)); rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
} }
spatial::SpatYieldOp::create(rewriter, loc, result); spatial::SpatYieldOp::create(rewriter, loc, result);
return success();
}); });
return computeOp.getResult(0); if (failed(computeOp))
return failure();
return computeOp->getResult(0);
} }
struct MatMulShapeInfo { struct MatMulShapeInfo {
@@ -841,7 +867,9 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
numOutHSlices, numOutHSlices,
rewriter, rewriter,
loc); loc);
Value result = createBatchedReductionCompute(batchOp.getResult(0), if (failed(batchOp))
return failure();
auto result = createBatchedReductionCompute(batchOp->getResult(0),
partialPiecesType, partialPiecesType,
directOutType, directOutType,
paddedOutType, paddedOutType,
@@ -849,14 +877,17 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
numKSlices, numKSlices,
rewriter, rewriter,
loc); loc);
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm) if (useTransposedForm)
result = transposeBatchedOutput( finalResult = transposeBatchedOutput(
result, finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter, rewriter,
loc); loc);
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, finalResult);
return success(); return success();
} }
} }
@@ -873,16 +904,21 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
false, false,
rewriter, rewriter,
loc); loc);
Value result = if (failed(batchOp))
createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc); return failure();
auto result =
createBatchedDynamicOutputCompute(batchOp->getResult(0), scalarPiecesType, directOutType, rewriter, loc);
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm) if (useTransposedForm)
result = transposeBatchedOutput( finalResult = transposeBatchedOutput(
result, finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter, rewriter,
loc); loc);
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, finalResult);
return success(); return success();
} }
}; };
@@ -12,6 +12,7 @@
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -275,38 +276,46 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); auto outputLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(outputLoop.getBody()); rewriter,
loc,
Value outputPatchIndex = outputLoop.getInductionVar(); c0,
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front(); cOutputPatchCount,
c1,
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); ValueRange {pooledOutputInit},
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); [&](OpBuilder&,
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); Location nestedLoc,
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); Value outputPatchIndex,
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); ValueRange iterArgs,
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); SmallVectorImpl<Value>& yielded) {
Value pooledOutputAcc = iterArgs.front();
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
Value batchPatchIndex =
arith::RemUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
Value updatedOutput = pooledOutputAcc; Value updatedOutput = pooledOutputAcc;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize); const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow = Value reducedWindow =
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>); createPoolFillTensor(rewriter, nestedLoc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value paddedInH = windowBaseH; Value paddedInH = windowBaseH;
if (kernelH * dilationHeight != 0) { if (kernelH * dilationHeight != 0) {
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight); Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); paddedInH = arith::AddIOp::create(rewriter, nestedLoc, paddedInH, kernelHOffset);
} }
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
Value paddedInW = windowBaseW; Value paddedInW = windowBaseW;
if (kernelW * dilationWidth != 0) { if (kernelW * dilationWidth != 0) {
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth); Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); paddedInW = arith::AddIOp::create(rewriter, nestedLoc, paddedInW, kernelWOffset);
} }
SmallVector<OpFoldResult> offsets = { SmallVector<OpFoldResult> offsets = {
@@ -315,28 +324,34 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = { SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue = Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, nestedLoc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeTileTensor(rewriter, loc, windowValue); windowValue = materializeTileTensor(rewriter, nestedLoc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue); reducedWindow = ReduceOp::create(rewriter, nestedLoc, tileType, reducedWindow, windowValue);
} }
} }
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
SmallVector<OpFoldResult> scaleOffsets = { SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex}; rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels), rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = { SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create( Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); rewriter, nestedLoc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice); scaleSlice = materializeTileTensor(rewriter, nestedLoc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); reducedWindow = spatial::SpatVMulOp::create(rewriter, nestedLoc, tileType, reducedWindow, scaleSlice);
} }
SmallVector<OpFoldResult> outputOffsets = { SmallVector<OpFoldResult> outputOffsets = {
@@ -348,13 +363,15 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
SmallVector<OpFoldResult> outputStrides = { SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create( updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides); rewriter, nestedLoc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
} }
yielded.push_back(updatedOutput);
return success();
});
if (failed(outputLoop))
return failure();
scf::YieldOp::create(rewriter, loc, updatedOutput); spatial::SpatYieldOp::create(rewriter, loc, outputLoop->results.front());
rewriter.setInsertionPointAfter(outputLoop);
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -42,7 +43,7 @@ static Value buildLoopSoftmaxSlice(Value input,
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides); return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
} }
static Value buildLoopSoftmaxNest(Value input, static FailureOr<Value> buildLoopSoftmaxNest(Value input,
Value accumulator, Value accumulator,
RankedTensorType inputType, RankedTensorType inputType,
int64_t axis, int64_t axis,
@@ -57,38 +58,50 @@ static Value buildLoopSoftmaxNest(Value input,
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis)); Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator}); auto loop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(loop.getBody()); rewriter,
loc,
Value loopIndex = loop.getInductionVar(); c0,
Value loopAccumulator = loop.getRegionIterArgs().front(); cUpper,
c1,
ValueRange {accumulator},
[&](OpBuilder& builder, Location nestedLoc, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
outerIndices.push_back(loopIndex); outerIndices.push_back(loopIndex);
Value updatedAccumulator = auto updatedAccumulator =
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc); buildLoopSoftmaxNest(input, iterArgs.front(), inputType, axis + 1, outerIndices, rewriter, nestedLoc);
outerIndices.pop_back(); outerIndices.pop_back();
if (failed(updatedAccumulator))
scf::YieldOp::create(rewriter, loc, updatedAccumulator); return failure();
rewriter.setInsertionPointAfter(loop); yielded.push_back(*updatedAccumulator);
return loop.getResult(0); return success();
});
if (failed(loop))
return failure();
return loop->results.front();
} }
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { static FailureOr<Value> createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = auto computeOp = createSpatCompute<numInputs>(
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult {
if (inputType.getRank() == 1) { if (inputType.getRank() == 1) {
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult(); Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
spatial::SpatYieldOp::create(rewriter, loc, softmax); spatial::SpatYieldOp::create(rewriter, loc, softmax);
return; return success();
} }
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType()); Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
SmallVector<Value> outerIndices; SmallVector<Value> outerIndices;
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc); auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, result); if (failed(result))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *result);
return success();
}); });
return computeOp.getResult(0); if (failed(computeOp))
return failure();
return computeOp->getResult(0);
} }
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -108,7 +121,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value input = adaptor.getInput(); Value input = adaptor.getInput();
Value result; Value result;
if (*axis == inputType.getRank() - 1) { if (*axis == inputType.getRank() - 1) {
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); auto computed = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
if (failed(computed))
return failure();
result = *computed;
} }
else { else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
@@ -122,8 +138,10 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc()); Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc()); if (failed(transposedResult))
return failure();
result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -9,6 +9,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -192,13 +193,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.setInsertionPointAfter(compute); rewriter.setInsertionPointAfter(compute);
auto newCompute = auto laneCountAttr = pim::getCheckedI32Attr(
spatial::SpatComputeBatch::create(rewriter, rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
compute.getLoc(), if (failed(laneCountAttr))
compute.getResultTypes(), return failure();
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())), auto newCompute = spatial::SpatComputeBatch::create(
promoted->newWeights, rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
promoted->newInputs);
auto laneArg = compute.getLaneArgument(); auto laneArg = compute.getLaneArgument();
if (!laneArg) if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
@@ -5,6 +5,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -26,7 +27,7 @@ static Value buildNearestAsymmetricIndex(
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast); return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
} }
static Value buildNearestResizeLoop(Value input, static FailureOr<Value> buildNearestResizeLoop(Value input,
RankedTensorType inputType, RankedTensorType inputType,
RankedTensorType resultType, RankedTensorType resultType,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
@@ -48,54 +49,94 @@ static Value buildNearestResizeLoop(Value input,
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType); Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit}); auto batchLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(batchLoop.getBody()); rewriter,
loc,
c0,
cOutputN,
c1,
ValueRange {outputInit},
[&](OpBuilder&, Location nestedLoc, Value outputN, ValueRange batchIterArgs, SmallVectorImpl<Value>& batchYielded) {
Value outputBatchAcc = batchIterArgs.front();
Value inputN =
buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, nestedLoc);
Value outputN = batchLoop.getInductionVar(); auto channelLoop = buildNormalizedScfFor(
Value outputBatchAcc = batchLoop.getRegionIterArgs().front(); rewriter,
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc); nestedLoc,
c0,
cOutputC,
c1,
ValueRange {outputBatchAcc},
[&](OpBuilder&,
Location channelLoc,
Value outputC,
ValueRange channelIterArgs,
SmallVectorImpl<Value>& channelYielded) {
Value outputChannelAcc = channelIterArgs.front();
Value inputC = buildNearestAsymmetricIndex(
outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, channelLoc);
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc}); auto heightLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(channelLoop.getBody()); rewriter,
channelLoc,
c0,
cOutputH,
c1,
ValueRange {outputChannelAcc},
[&](OpBuilder&,
Location heightLoc,
Value outputH,
ValueRange heightIterArgs,
SmallVectorImpl<Value>& heightYielded) {
Value outputHeightAcc = heightIterArgs.front();
Value inputH = buildNearestAsymmetricIndex(
outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, heightLoc);
Value outputC = channelLoop.getInductionVar(); auto widthLoop = buildNormalizedScfFor(
Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); rewriter,
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc); heightLoc,
c0,
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); cOutputW,
rewriter.setInsertionPointToStart(heightLoop.getBody()); c1,
ValueRange {outputHeightAcc},
Value outputH = heightLoop.getInductionVar(); [&](OpBuilder&,
Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); Location widthLoc,
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc); Value outputW,
ValueRange widthIterArgs,
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); SmallVectorImpl<Value>& widthYielded) {
rewriter.setInsertionPointToStart(widthLoop.getBody()); Value outputWidthAcc = widthIterArgs.front();
Value inputW = buildNearestAsymmetricIndex(
Value outputW = widthLoop.getInductionVar(); outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, widthLoc);
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW}; SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice = Value inputSlice = tensor::ExtractSliceOp::create(
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides); rewriter, widthLoc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW}; SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
Value updatedOutput = Value updatedOutput = tensor::InsertSliceOp::create(
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides); rewriter, widthLoc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
scf::YieldOp::create(rewriter, loc, updatedOutput); widthYielded.push_back(updatedOutput);
return success();
rewriter.setInsertionPointAfter(widthLoop); });
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0)); if (failed(widthLoop))
return failure();
rewriter.setInsertionPointAfter(heightLoop); heightYielded.push_back(widthLoop->results.front());
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0)); return success();
});
rewriter.setInsertionPointAfter(channelLoop); if (failed(heightLoop))
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0)); return failure();
channelYielded.push_back(heightLoop->results.front());
rewriter.setInsertionPointAfter(batchLoop); return success();
return batchLoop.getResult(0); });
if (failed(channelLoop))
return failure();
batchYielded.push_back(channelLoop->results.front());
return success();
});
if (failed(batchLoop))
return failure();
return batchLoop->results.front();
} }
struct Resize : OpConversionPattern<ONNXResizeOp> { struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -120,12 +161,17 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions."); return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
auto computeOp = auto computeOp = createSpatCompute<1>(
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) -> LogicalResult {
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc()); auto result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); if (failed(result))
return failure();
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), *result);
return success();
}); });
rewriter.replaceOp(resizeOp, computeOp.getResults()); if (failed(computeOp))
return failure();
rewriter.replaceOp(resizeOp, computeOp->getResults());
return success(); return success();
} }
}; };
@@ -10,6 +10,7 @@
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -25,14 +26,21 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
}); });
} }
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount())); coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++)); auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
if (failed(checkedCoreId))
return failure();
coreIds.push_back(*checkedCoreId);
++fallbackCoreId;
}
return coreIds; return coreIds;
} }
@@ -102,21 +110,24 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
"resultful compute_batch lowering currently requires a spat.in_parallel terminator"); "resultful compute_batch lowering currently requires a spat.in_parallel terminator");
} }
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
if (failed(coreIds))
return failure();
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs; SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty()) if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp); rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, auto laneCountAttr = pim::getCheckedI32Attr(
loc, rewriter, computeBatchOp, static_cast<uint64_t>(computeBatchOp.getLaneCount()), "pim core_batch lane count");
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), if (failed(laneCountAttr))
ValueRange(batchWeights), return failure();
ValueRange(batchInputs)); auto coreBatchOp =
pim::PimCoreBatchOp::create(rewriter, loc, *laneCountAttr, ValueRange(batchWeights), ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes( coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())}); {static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
SmallVector<unsigned> returnOperandIndices; SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) { if (computeBatchOp.getNumResults() != 0) {
@@ -160,14 +171,11 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
auto newArgType = cast<ShapedType>(newArg.getType()); auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), newArg);
loc, if (failed(sizeAttr))
outputBuffer.getType(), return failure();
zeroOffset, auto copied = pim::PimMemCopyHostToDevOp::create(
zeroOffset, rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, newArg, *sizeAttr)
outputBuffer,
newArg,
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput(); .getOutput();
mapper.map(*oldArg, copied); mapper.map(*oldArg, copied);
} }
@@ -209,6 +217,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
auto hostTargetType = cast<ShapedType>(hostTarget.getType()); auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
if (failed(sizeAttr))
return failure();
pim::PimMemCopyDevToHostOp::create(rewriter, pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(), insertSlice.getLoc(),
hostTarget.getType(), hostTarget.getType(),
@@ -216,7 +227,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
zeroOffset, zeroOffset,
hostTarget, hostTarget,
mappedSource, mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource)); *sizeAttr);
} }
continue; continue;
} }
@@ -232,14 +243,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
auto clonedType = cast<ShapedType>(clonedTensor.getType()); auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter, auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), clonedTensor);
loc, if (failed(sizeAttr))
outputBuffer.getType(), return failure();
zeroOffset, auto copied =
zeroOffset, pim::PimMemCopyHostToDevOp::create(
outputBuffer, rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, clonedTensor, *sizeAttr)
clonedTensor,
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput(); .getOutput();
mapper.map(toTensorOp.getResult(), copied); mapper.map(toTensorOp.getResult(), copied);
continue; continue;
+6 -2
View File
@@ -5,14 +5,18 @@
#include <cassert> #include <cassert>
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { FailureOr<IntegerAttr> getTensorSizeInBytesAttr(Builder& builder, Operation* anchor, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); auto byteSize = pim::getCheckedShapedTypeSizeInBytes(cast<ShapedType>(value.getType()), anchor, "tensor byte size");
if (failed(byteSize))
return failure();
return pim::getCheckedI32Attr(builder, anchor, *byteSize, "tensor byte size");
} }
Operation* getEarliestUserWithinBlock(mlir::Value value) { Operation* getEarliestUserWithinBlock(mlir::Value value) {
+3 -1
View File
@@ -1,12 +1,14 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir { namespace onnx_mlir {
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::FailureOr<mlir::IntegerAttr>
getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Operation* anchor, mlir::Value value);
template <class T> template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) { size_t rangeLength(const mlir::iterator_range<T> range) {
@@ -9,6 +9,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -54,10 +55,15 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
} }
} }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt()); return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
return static_cast<int32_t>(fallbackCoreId++); auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeOp, "fallback spatial compute core id");
if (failed(checkedCoreId))
return failure();
++fallbackCoreId;
return *checkedCoreId;
} }
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
@@ -163,10 +169,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg->getType()); auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), *blockArg);
if (failed(sizeAttr))
return failure();
Value received = Value received =
PimReceiveOp::create( PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, *sizeAttr, receiveOp.getSourceCoreId())
.getOutput(); .getOutput();
blockArg->replaceAllUsesWith(received); blockArg->replaceAllUsesWith(received);
markOpToRemove(receiveOp); markOpToRemove(receiveOp);
@@ -206,8 +214,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
if (!computeOp.getWeights().empty()) if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create( auto checkedCoreId = getPimCoreIdForComputeOp(computeOp, coreId);
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); if (failed(checkedCoreId))
return failure();
auto coreIdAttr = pim::getCheckedI32Attr(rewriter, computeOp, static_cast<int64_t>(*checkedCoreId), "pim core id");
if (failed(coreIdAttr))
return failure();
auto coreOp = PimCoreOp::create(rewriter, loc, ValueRange(computeWeights), *coreIdAttr);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
@@ -226,6 +239,9 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
if (!inputType) if (!inputType)
return computeOp.emitOpError("expected shaped compute input during pim.core lowering"); return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, computeOp.getOperation(), input);
if (failed(sizeAttr))
return failure();
auto copied = auto copied =
PimMemCopyHostToDevOp::create(rewriter, PimMemCopyHostToDevOp::create(rewriter,
loc, loc,
@@ -234,7 +250,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0), getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
outputBuffer, outputBuffer,
input, input,
getTensorSizeInBytesAttr(rewriter, input)) *sizeAttr)
.getOutput(); .getOutput();
blockArg->replaceAllUsesWith(copied); blockArg->replaceAllUsesWith(copied);
} }
@@ -14,8 +14,10 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create( auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId()); if (failed(sizeAttr))
return failure();
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@@ -32,12 +34,11 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
auto outputType = cast<ShapedType>(op.getResult().getType()); auto outputType = cast<ShapedType>(op.getResult().getType());
Value outputBuffer = Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = pim::PimReceiveOp::create(rewriter, auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
op.getLoc(), if (failed(sizeAttr))
op.getResult().getType(), return failure();
outputBuffer, Value received = pim::PimReceiveOp::create(
getTensorSizeInBytesAttr(rewriter, op.getResult()), rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
op.getSourceCoreId())
.getOutput(); .getOutput();
rewriter.replaceOp(op, received); rewriter.replaceOp(op, received);
return success(); return success();
@@ -12,6 +12,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -71,6 +72,20 @@ static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<i
return indices; return indices;
} }
static FailureOr<int64_t>
getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* anchor, StringRef fieldName) {
if (elementOffset < 0) {
anchor->emitOpError() << fieldName << " requires a nonnegative element offset";
return failure();
}
auto byteOffset =
pim::checkedMul(static_cast<uint64_t>(elementOffset), static_cast<uint64_t>(elementSize), anchor, fieldName);
if (failed(byteOffset))
return failure();
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) { SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
@@ -360,18 +375,21 @@ static void cloneHelperChain(Value sourceValue,
} }
} }
static Value emitHostCopy(IRRewriter& rewriter, static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
Location loc, Location loc,
Value outputTensor, Value outputTensor,
Value sourceValue, Value sourceValue,
int32_t hostTargetOffset, int64_t hostTargetOffset,
int32_t deviceSourceOffset, int64_t deviceSourceOffset,
int32_t sizeInBytes, uint64_t sizeInBytes,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset); Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset);
Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset); Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
auto sizeAttr = pim::getCheckedI32Attr(rewriter, anchorOp, sizeInBytes, "return-path host copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyDevToHostOp::create(rewriter, return PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
@@ -379,7 +397,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
deviceSourceOffsetValue, deviceSourceOffsetValue,
outputTensor, outputTensor,
sourceValue, sourceValue,
rewriter.getI32IntegerAttr(sizeInBytes)) *sizeAttr)
.getOutput(); .getOutput();
} }
@@ -433,18 +451,15 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
markOpToRemove(op); markOpToRemove(op);
auto storedType = cast<ShapedType>(currentStoredValue.getType()); auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType()); auto byteSize = pim::getCheckedShapedTypeSizeInBytes(storedType, producerOp, "return-path host copy byte size");
if (failed(byteSize))
return ReturnPathLoweringResult::Failure;
if (auto storedOp = currentStoredValue.getDefiningOp()) if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp); rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(rewriter, auto copied = emitHostCopy(rewriter, loc, outputTensor, currentStoredValue, 0, 0, *byteSize, constantFolder);
loc, if (failed(copied))
outputTensor, return ReturnPathLoweringResult::Failure;
currentStoredValue,
0,
0,
static_cast<int32_t>(storedType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
@@ -455,23 +470,25 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
if (failed(byteSize))
return ReturnPathLoweringResult::Failure;
rewriter.setInsertionPointAfterValue(storedValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter, auto copied = emitHostCopy(rewriter, loc, outputTensor, storedValue, 0, 0, *byteSize, constantFolder);
loc, if (failed(copied))
outputTensor, return ReturnPathLoweringResult::Failure;
storedValue,
0,
0,
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
} }
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
auto storedByteSize =
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "concat return-path copy byte size");
if (failed(storedByteSize))
return ReturnPathLoweringResult::Failure;
for (Operation* concatOp : concatReturnUse->concatChain) for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(concatOp); markOpToRemove(concatOp);
@@ -480,14 +497,13 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType()); auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter, auto hostOffset = getCheckedByteOffset(flatOffset, elementSize, producerOp, "concat return-path host offset");
loc, if (failed(hostOffset))
outputTensor, return ReturnPathLoweringResult::Failure;
storedValue, auto copied =
static_cast<int32_t>(flatOffset * elementSize), emitHostCopy(rewriter, loc, outputTensor, storedValue, *hostOffset, 0, *storedByteSize, constantFolder);
0, if (failed(copied))
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize), return ReturnPathLoweringResult::Failure;
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
@@ -531,14 +547,18 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
rewriter.setInsertionPointAfter(elementSlice); rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
outputTensor = emitHostCopy(rewriter, auto hostOffset =
loc, getCheckedByteOffset(destinationFlatOffset, elementSize, producerOp, "concat helper return-path host offset");
outputTensor, if (failed(hostOffset))
elementSlice.getResult(), return ReturnPathLoweringResult::Failure;
static_cast<int32_t>(destinationFlatOffset * elementSize), auto elementByteSize = pim::checkedCast<uint64_t>(elementSize, producerOp, "return-path scalar copy byte size");
0, if (failed(elementByteSize))
static_cast<int32_t>(elementSize), return ReturnPathLoweringResult::Failure;
constantFolder); auto copied = emitHostCopy(
rewriter, loc, outputTensor, elementSlice.getResult(), *hostOffset, 0, *elementByteSize, constantFolder);
if (failed(copied))
return ReturnPathLoweringResult::Failure;
outputTensor = *copied;
} }
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
@@ -25,8 +25,9 @@
#include <cassert> #include <cassert>
#include <utility> #include <utility>
#include "Common/PimCommon.hpp"
#include "Common/IR/ConstantUtils.hpp" #include "Common/IR/ConstantUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp" #include "Conversion/SpatialToPim/Patterns.hpp"
@@ -75,7 +76,7 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
IntegerAttr {}); IntegerAttr {});
} }
static Value createZeroedDeviceHVector(IRRewriter& rewriter, static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc, Location loc,
RankedTensorType tensorType, RankedTensorType tensorType,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
@@ -83,13 +84,20 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType))); auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(tensorType, outputBuffer.getOperation(), "host-to-device zero copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr =
pim::getCheckedI32Attr(rewriter, outputBuffer.getOperation(), *byteSize, "host-to-device zero copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyHostToDevOp::create( return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr) rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, *sizeAttr)
.getOutput(); .getOutput();
} }
static Value static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType()); auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape(); ArrayRef<int64_t> shape = vectorType.getShape();
@@ -101,10 +109,18 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto paddedType = RankedTensorType::get( auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); {shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); auto zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0); if (failed(zeroed))
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType))); return failure();
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput(); Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed->getDefiningOp(), 0);
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(vectorType, zeroed->getDefiningOp(), "device padding copy byte size");
if (failed(byteSize))
return failure();
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
} }
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
@@ -234,7 +250,11 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
} }
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); if (failed(enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter))) {
funcOp.emitOpError("failed to enlarge VMM output tensors to crossbar size");
signalPassFailure();
return;
}
replaceReturnWithOutputBuffers(returnOp, rewriter); replaceReturnWithOutputBuffers(returnOp, rewriter);
eraseOpsToRemove(); eraseOpsToRemove();
@@ -271,8 +291,9 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0"); dumpModule(moduleOp, "pim0");
} }
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
bool hasFailure = false;
funcOp.walk([&](PimVMMOp vmmOp) { funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType()); auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -280,19 +301,23 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar"); assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp); rewriter.setInsertionPoint(vmmOp);
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); auto paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
if (failed(paddedInput)) {
hasFailure = true;
return WalkResult::interrupt();
}
auto paddedOutputType = RankedTensorType::get( auto paddedOutputType = RankedTensorType::get(
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding()); {outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize) Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
? vmmOp.getOutputBuffer() ? vmmOp.getOutputBuffer()
: createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult(); : createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();
vmmOp.getInputMutable().assign(paddedInput); vmmOp.getInputMutable().assign(*paddedInput);
vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer); vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
vmmOp.getOutput().setType(paddedOutputType); vmmOp.getOutput().setType(paddedOutputType);
if (outputShape[1] == static_cast<int64_t>(crossbarSize)) if (outputShape[1] == static_cast<int64_t>(crossbarSize))
return; return WalkResult::advance();
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
@@ -302,13 +327,16 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp}; SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions); vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
return WalkResult::advance();
}); });
return success(!hasFailure);
} }
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
bool hasFailure = false;
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType()); auto tensorType = cast<ShapedType>(inputTensor.getType());
@@ -319,17 +347,28 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto offsetBytes = pim::checkedMul(
static_cast<size_t>(elementsOffset), elementByteSize, deviceTensor.getOperation(), "host input byte offset");
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(tensorType, deviceTensor.getOperation(), "host input copy byte size");
auto sizeAttr =
succeeded(byteSize)
? pim::getCheckedI32Attr(rewriter, deviceTensor.getOperation(), *byteSize, "host input copy byte size")
: FailureOr<IntegerAttr>(failure());
if (failed(offsetBytes) || failed(sizeAttr)) {
hasFailure = true;
return;
}
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create( auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter, rewriter,
loc, loc,
tensorType, tensorType,
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0), getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
getOrCreateIndexConstant( getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(*offsetBytes)),
constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize)),
deviceTensor, deviceTensor,
inputTensor, inputTensor,
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize))); *sizeAttr);
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
}; };
@@ -347,7 +386,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
} }
} }
return success(); return success(!hasFailure);
} }
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) { void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
@@ -64,7 +64,7 @@ private:
void markOpToRemove(mlir::Operation* op); void markOpToRemove(mlir::Operation* op);
void eraseOpsToRemove(); void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); mlir::LogicalResult enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
}; };
} // namespace raptor } // namespace raptor
+4 -1
View File
@@ -2,7 +2,10 @@ add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td) add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization) add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/StaticMemoryCoalescing) add_subdirectory(Transforms/MemoryCoalescing)
add_subdirectory(Transforms/HostConstantFolding)
add_subdirectory(Transforms/HostConstantMaterialization)
add_subdirectory(Transforms/Verification)
add_pim_library(PimOps add_pim_library(PimOps
PimOps.hpp PimOps.hpp
@@ -3,6 +3,7 @@
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
@@ -11,24 +12,25 @@ using namespace bufferization;
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
return memrefValue; return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType); auto sizeInBytes =
getCheckedShapedTypeSizeInBytes(shapedType, contiguousBuffer.getDefiningOp(), "contiguous copy byte size");
if (failed(sizeInBytes))
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
auto sizeAttr =
getCheckedI32Attr(rewriter, contiguousBuffer.getDefiningOp(), *sizeInBytes, "contiguous copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyOp::create(rewriter, return PimMemCopyOp::create(
loc, rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
contiguousType,
zeroOffset,
zeroOffset,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput(); .getOutput();
} }
@@ -5,7 +5,8 @@
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); llvm::FailureOr<mlir::Value>
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
@@ -1,10 +1,13 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace mlir; using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType()); auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type)); auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
return builder.getI32IntegerAttr(sizeInBytes); if (failed(byteSize))
return failure();
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
} }
@@ -5,7 +5,8 @@
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref); mlir::FailureOr<mlir::IntegerAttr>
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -4,8 +4,10 @@
#include "ContiguityPatterns.hpp" #include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace mlir; using namespace mlir;
@@ -85,7 +87,13 @@ static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
static FailureOr<int64_t> getShapedByteSize(MemRefType type) { static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType())) if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()))
return failure(); return failure();
return static_cast<int64_t>(getShapedTypeSizeInBytes(type)); auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "normalized copy byte size");
if (failed(byteSize))
return failure();
if (*byteSize > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
return failure();
return static_cast<int64_t>(*byteSize);
} }
static FailureOr<SmallVector<int64_t>> static FailureOr<SmallVector<int64_t>>
@@ -325,12 +333,11 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
if (plan->kind == CopyRewritePlan::Kind::Direct) { if (plan->kind == CopyRewritePlan::Kind::Direct) {
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset); Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset); Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
auto newCopyOp = createCopyOp(loc, auto checkedDirectBytes = pim::checkedI32(plan->directBytes, anchorOp, "normalized direct copy byte size");
plan->target.base, if (failed(checkedDirectBytes))
plan->source.base, return failure();
newTargetOffset, auto newCopyOp =
newSourceOffset, createCopyOp(loc, plan->target.base, plan->source.base, newTargetOffset, newSourceOffset, *checkedDirectBytes);
static_cast<int32_t>(plan->directBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.replaceOp(copyOp, replacementValue); rewriter.replaceOp(copyOp, replacementValue);
return success(); return success();
@@ -339,23 +346,30 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
Value c0 = createIndexConstant(rewriter, anchorOp, 0); Value c0 = createIndexConstant(rewriter, anchorOp, 0);
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape)); Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
Value cStep = createIndexConstant(rewriter, anchorOp, 1); Value cStep = createIndexConstant(rewriter, anchorOp, 1);
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {}); auto loop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(loop.getBody()); rewriter,
loc,
c0,
cUpper,
cStep,
ValueRange {},
[&](OpBuilder&, Location nestedLoc, Value inductionVar, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
SmallVector<Value> outerIndices = SmallVector<Value> outerIndices =
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape); materializeDelinearizedIndices(rewriter, nestedLoc, anchorOp, inductionVar, plan->loop.outerShape);
Value loopTargetOffset = materializeOuterByteOffset( Value loopTargetOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides); rewriter, nestedLoc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
Value loopSourceOffset = materializeOuterByteOffset( Value loopSourceOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides); rewriter, nestedLoc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
auto newCopyOp = createCopyOp(loc, auto checkedChunkBytes = pim::checkedI32(plan->loop.chunkBytes, anchorOp, "normalized loop copy byte size");
plan->target.base, if (failed(checkedChunkBytes))
plan->source.base, return failure();
loopTargetOffset, auto newCopyOp = createCopyOp(
loopSourceOffset, nestedLoc, plan->target.base, plan->source.base, loopTargetOffset, loopSourceOffset, *checkedChunkBytes);
static_cast<int32_t>(plan->loop.chunkBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy"); assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.setInsertionPointAfter(loop); return success();
});
if (failed(loop))
return failure();
rewriter.replaceOp(copyOp, replacementValue); rewriter.replaceOp(copyOp, replacementValue);
return success(); return success();
} }
@@ -148,7 +148,10 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter)); auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguous))
return failure();
inputs.push_back(*contiguous);
} }
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
@@ -179,12 +182,12 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter, replaceOpWithNewBufferizedOp<PimSendOp>(
op, rewriter, op, *contiguousInput, sendOp.getSizeAttr(), sendOp.getTargetCoreId());
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreId());
return success(); return success();
} }
}; };
@@ -407,11 +410,13 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimTransposeOp>( replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput); rewriter, op, contiguousOutput.getType(), *contiguousInput, transposeOp.getPermutation(), contiguousOutput);
return success(); return success();
} }
}; };
@@ -451,11 +456,13 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>( replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput); rewriter, op, contiguousOutput.getType(), *weightOpt, *contiguousInput, contiguousOutput);
return success(); return success();
} }
}; };
@@ -490,12 +497,16 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>( replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
return success(); return success();
} }
}; };
@@ -523,12 +534,16 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVVDMulOp>( replaceOpWithNewBufferizedOp<PimVVDMulOp>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
return success(); return success();
} }
}; };
@@ -559,10 +574,12 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput); replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), *contiguousInput, contiguousOutput);
return success(); return success();
} }
}; };
@@ -42,7 +42,9 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
return failure(); return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource()); auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
copyOp.getTarget().getType(), copyOp.getTarget().getType(),
@@ -50,7 +52,7 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
zeroOffset, zeroOffset,
copyOp.getTarget(), copyOp.getTarget(),
copyOp.getSource(), copyOp.getSource(),
sizeAttr); *sizeAttr);
rewriter.eraseOp(copyOp); rewriter.eraseOp(copyOp);
return success(); return success();
} }
@@ -0,0 +1,12 @@
add_pim_library(OMPimHostConstantFolding
Common.cpp
Patterns/Constant.cpp
HostConstantFoldingPass.cpp
Patterns/Subview.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRLinalgDialect
OMPimCommon
)
@@ -7,6 +7,7 @@
#include "../Common.hpp" #include "../Common.hpp"
#include "../Patterns.hpp" #include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -120,16 +121,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = getShapedTypeSizeInBytes(initType); auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(initType, mapOp, "host constant folding byte size");
if (failed(sizeInBytes))
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0); Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
pim::PimMemCopyOp::create(rewriter, auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
mapOp.getLoc(), if (failed(sizeAttr))
initType, return failure();
zeroOffset, pim::PimMemCopyOp::create(
zeroOffset, rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(mapOp); rewriter.eraseOp(mapOp);
return success(); return success();
} }
@@ -0,0 +1,9 @@
add_pim_library(OMPimHostConstantMaterialization
MaterializeHostConstantsPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -12,6 +12,7 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -65,11 +66,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
continue; continue;
} }
int64_t totalBytes = -1; auto type = dyn_cast<ShapedType>(originalValue.getType());
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape()) auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type)); : FailureOr<uint64_t>(failure());
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { auto totalBytesAttr =
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); succeeded(totalBytes)
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
: FailureOr<IntegerAttr>(failure());
if (failed(totalBytesAttr)
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
hasFailure = true; hasFailure = true;
continue; continue;
} }
@@ -84,15 +89,14 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0); Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset); Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
Value copiedValue = Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(), op->getLoc(),
originalType, originalType,
zeroOffset, zeroOffset,
hostOffset, hostOffset,
deviceDst, deviceDst,
getGlobalOp.getResult(), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) *totalBytesAttr)
.getOutput(); .getOutput();
cachedByType[originalType] = copiedValue; cachedByType[originalType] = copiedValue;
@@ -0,0 +1,14 @@
add_pim_library(OMPimMemoryCoalescing
MemoryCoalescing.cpp
MemoryCoalescing.hpp
MemoryCoalescingPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -10,7 +10,7 @@
#include <limits> #include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
using namespace mlir; using namespace mlir;
@@ -32,6 +32,13 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType())); return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
} }
static Operation* getTopLevelAncestorInBody(Operation* op, Block& body) {
Operation* current = op;
while (current && current->getBlock() != &body)
current = current->getParentOp();
return current;
}
static FailureOr<uint64_t> static FailureOr<uint64_t>
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) { getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp); uint64_t endInstruction = opOrder.lookup(allocOp);
@@ -42,7 +49,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
while (!pendingValues.empty()) { while (!pendingValues.empty()) {
Value value = pendingValues.pop_back_val(); Value value = pendingValues.pop_back_val();
for (Operation* user : value.getUsers()) { for (Operation* user : value.getUsers()) {
if (user->getBlock() != &body) Operation* orderedUser = getTopLevelAncestorInBody(user, body);
if (!orderedUser)
return failure(); return failure();
if (!visited.insert(user).second) if (!visited.insert(user).second)
continue; continue;
@@ -51,6 +59,15 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
for (Value result : user->getResults()) for (Value result : user->getResults())
pendingValues.push_back(result); pendingValues.push_back(result);
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return failure();
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands()))
if (operand == value)
pendingValues.push_back(forOp.getResult(index));
}
if (auto forOp = dyn_cast<scf::ForOp>(user)) { if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value) if (initArg == value)
@@ -66,7 +83,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
} }
} }
auto order = opOrder.find(user); auto order = opOrder.find(orderedUser);
if (order == opOrder.end()) if (order == opOrder.end())
return failure(); return failure();
endInstruction = std::max(endInstruction, order->second); endInstruction = std::max(endInstruction, order->second);
@@ -78,8 +95,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
} // namespace } // namespace
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation* coreLikeOp) { MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation* coreLikeOp) {
StaticMemoryCoalescingAnalysis analysis; MemoryCoalescingAnalysis analysis;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return analysis; return analysis;
@@ -107,18 +124,19 @@ StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation
} }
analysis.candidates.push_back( analysis.candidates.push_back(
StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)}); AllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
} }
return analysis; return analysis;
} }
StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) { MemoryCoalescingStats
StaticMemoryCoalescingStats stats; coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) {
auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp); MemoryCoalescingStats stats;
stats.skippedAllocations = analysis.skippedAllocations; stats.skippedAllocations = analysis.skippedAllocations;
llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) { auto candidates = analysis.candidates;
llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) {
if (lhs.startInstruction != rhs.startInstruction) if (lhs.startInstruction != rhs.startInstruction)
return lhs.startInstruction < rhs.startInstruction; return lhs.startInstruction < rhs.startInstruction;
return lhs.endInstruction < rhs.endInstruction; return lhs.endInstruction < rhs.endInstruction;
@@ -132,7 +150,7 @@ StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, Rewriter
SmallVector<ActiveStorage> active; SmallVector<ActiveStorage> active;
SmallVector<memref::AllocOp> freeList; SmallVector<memref::AllocOp> freeList;
for (StaticAllocationCandidate& candidate : analysis.candidates) { for (AllocationCandidate& candidate : candidates) {
for (auto it = active.begin(); it != active.end();) { for (auto it = active.begin(); it != active.end();) {
if (it->endInstruction < candidate.startInstruction) { if (it->endInstruction < candidate.startInstruction) {
freeList.push_back(it->root); freeList.push_back(it->root);
@@ -8,27 +8,28 @@
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
struct StaticAllocationCandidate { struct AllocationCandidate {
mlir::memref::AllocOp alloc; mlir::memref::AllocOp alloc;
uint64_t startInstruction = 0; uint64_t startInstruction = 0;
uint64_t endInstruction = 0; uint64_t endInstruction = 0;
uint64_t sizeBytes = 0; uint64_t sizeBytes = 0;
}; };
struct StaticMemoryCoalescingAnalysis { struct MemoryCoalescingAnalysis {
llvm::SmallVector<StaticAllocationCandidate> candidates; llvm::SmallVector<AllocationCandidate> candidates;
uint64_t skippedAllocations = 0; uint64_t skippedAllocations = 0;
}; };
struct StaticMemoryCoalescingStats { struct MemoryCoalescingStats {
uint64_t removedAllocs = 0; uint64_t removedAllocs = 0;
uint64_t savedBytes = 0; uint64_t savedBytes = 0;
uint64_t skippedAllocations = 0; uint64_t skippedAllocations = 0;
}; };
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp); MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter); MemoryCoalescingStats
coalesceMemory(mlir::Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, mlir::RewriterBase& rewriter);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -10,10 +10,11 @@
#include "Common/IR/CompactAsmUtils.hpp" #include "Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" #include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir; using namespace mlir;
@@ -151,34 +152,39 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
file.close(); file.close();
} }
struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, OperationPass<ModuleOp>> { struct PimMemoryCoalescingPass : PassWrapper<PimMemoryCoalescingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMemoryCoalescingPass)
StringRef getArgument() const override { return "pim-static-memory-coalescing"; } StringRef getArgument() const override { return "pim-memory-coalescing"; }
StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; } StringRef getDescription() const override { return "Analyze local PIM memory reuse opportunities"; }
StaticMemoryCoalescingPass() = default; PimMemoryCoalescingPass() = default;
StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {} PimMemoryCoalescingPass(const PimMemoryCoalescingPass& pass) {}
void runOnOperation() override { void runOnOperation() override {
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
SmallVector<CoalescingReportEntry, 32> reportEntries; SmallVector<CoalescingReportEntry, 32> reportEntries;
uint64_t nextBatchId = 0; uint64_t nextBatchId = 0;
bool hasFailure = false;
getOperation().walk([&](Operation* op) { getOperation().walk([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op)) if (hasFailure || !isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return; return;
auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op); auto analysis = pim::analyzeMemoryCoalescingCandidates(op);
auto stats = pim::coalesceStaticMemory(op, rewriter); auto stats = pim::coalesceMemory(op, analysis, rewriter);
CoalescingReportRow row { CoalescingReportRow row {
analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes}; analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) { if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
reportEntries.push_back({CoalescingReportEntry::Kind::Core, auto checkedCoreId =
static_cast<uint64_t>(coreOp.getCoreId()), pim::checkedI32(static_cast<uint64_t>(coreOp.getCoreId()), coreOp, "memory coalescing core id");
{static_cast<int32_t>(coreOp.getCoreId())}, if (failed(checkedCoreId)) {
row}); hasFailure = true;
return;
}
reportEntries.push_back(
{CoalescingReportEntry::Kind::Core, static_cast<uint64_t>(coreOp.getCoreId()), {*checkedCoreId}, row});
return; return;
} }
@@ -191,6 +197,11 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
reportEntries.push_back(std::move(entry)); reportEntries.push_back(std::move(entry));
}); });
if (hasFailure) {
signalPassFailure();
return;
}
emitReport(reportEntries); emitReport(reportEntries);
dumpModule(getOperation(), "pim2_coalesced"); dumpModule(getOperation(), "pim2_coalesced");
} }
@@ -198,6 +209,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
} // namespace } // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); } std::unique_ptr<Pass> createPimMemoryCoalescingPass() { return std::make_unique<PimMemoryCoalescingPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,14 +0,0 @@
add_pim_library(OMPimStaticMemoryCoalescing
StaticMemoryCoalescing.cpp
StaticMemoryCoalescing.hpp
StaticMemoryCoalescingPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -0,0 +1,11 @@
add_pim_library(OMPimVerification
VerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
OMPimBufferization
PimOps
SpatialOps
)
@@ -23,7 +23,9 @@
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -36,6 +38,23 @@ using CpuId = size_t;
using ClassId = size_t; using ClassId = size_t;
using SlotId = size_t; using SlotId = size_t;
static FailureOr<int32_t> getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) {
return pim::checkedI32(static_cast<uint64_t>(cpu), anchor, fieldName);
}
static FailureOr<SmallVector<int32_t, 8>>
getCheckedCoreIds(Operation* anchor, ArrayRef<CpuId> cpus, StringRef fieldName) {
SmallVector<int32_t, 8> coreIds;
coreIds.reserve(cpus.size());
for (CpuId cpu : cpus) {
auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName);
if (failed(checkedCoreId))
return failure();
coreIds.push_back(*checkedCoreId);
}
return coreIds;
}
struct ProducerKey { struct ProducerKey {
ComputeInstance instance; ComputeInstance instance;
size_t resultIndex = 0; size_t resultIndex = 0;
@@ -498,7 +517,7 @@ LogicalResult collectHostOutputs(MaterializerState& state) {
return success(); return success();
} }
void createEmptyMaterializedOps(MaterializerState& state) { LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
Location loc = state.func.getLoc(); Location loc = state.func.getLoc();
Block& funcBlock = state.func.getBody().front(); Block& funcBlock = state.func.getBody().front();
@@ -524,8 +543,11 @@ void createEmptyMaterializedOps(MaterializerState& state) {
if (!materializedClass.isBatch) { if (!materializedClass.isBatch) {
auto compute = SpatCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); auto compute = SpatCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {});
compute.getProperties().setOperandSegmentSizes({0, 0}); compute.getProperties().setOperandSegmentSizes({0, 0});
compute->setAttr(onnx_mlir::kCoreIdAttrName, auto coreIdAttr =
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.front()))); pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id");
if (failed(coreIdAttr))
return failure();
compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);
Block* body = state.rewriter.createBlock(&compute.getBody()); Block* body = state.rewriter.createBlock(&compute.getBody());
state.rewriter.setInsertionPointToEnd(body); state.rewriter.setInsertionPointToEnd(body);
SmallVector<Value, 4> placeholderOutputs; SmallVector<Value, 4> placeholderOutputs;
@@ -534,7 +556,7 @@ void createEmptyMaterializedOps(MaterializerState& state) {
auto tensorType = dyn_cast<RankedTensorType>(resultType); auto tensorType = dyn_cast<RankedTensorType>(resultType);
if (!tensorType || !tensorType.hasStaticShape()) { if (!tensorType || !tensorType.hasStaticShape()) {
compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); compute.emitOpError("host-facing materialized compute results must be static ranked tensors");
continue; return failure();
} }
placeholderOutputs.push_back( placeholderOutputs.push_back(
tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult());
@@ -546,19 +568,17 @@ void createEmptyMaterializedOps(MaterializerState& state) {
continue; continue;
} }
auto batch = auto batchLaneCountAttr = pim::getCheckedI32Attr(
SpatComputeBatch::create(state.rewriter, state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count");
loc, if (failed(batchLaneCountAttr))
TypeRange(resultTypes), return failure();
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())), auto batch = SpatComputeBatch::create(
ValueRange {}, state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {});
ValueRange {});
batch.getProperties().setOperandSegmentSizes({0, 0}); batch.getProperties().setOperandSegmentSizes({0, 0});
SmallVector<int32_t> coreIds; auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id");
coreIds.reserve(materializedClass.cpus.size()); if (failed(coreIds))
for (CpuId cpu : materializedClass.cpus) return failure();
coreIds.push_back(static_cast<int32_t>(cpu)); batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds));
batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type, 4> blockArgTypes {state.rewriter.getIndexType()}; SmallVector<Type, 4> blockArgTypes {state.rewriter.getIndexType()};
SmallVector<Location, 4> blockArgLocs {loc}; SmallVector<Location, 4> blockArgLocs {loc};
@@ -575,6 +595,8 @@ void createEmptyMaterializedOps(MaterializerState& state) {
materializedClass.body = body; materializedClass.body = body;
state.rewriter.setInsertionPointAfter(batch.getOperation()); state.rewriter.setInsertionPointAfter(batch.getOperation());
} }
return success();
} }
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
@@ -787,7 +809,7 @@ SmallVector<int64_t, 16> widenToI64(ArrayRef<int32_t> values) {
return widened; return widened;
} }
Value createReceiveConcatLoop(MaterializerState& state, FailureOr<Value> createReceiveConcatLoop(MaterializerState& state,
Operation* anchor, Operation* anchor,
Operation* insertionPoint, Operation* insertionPoint,
RankedTensorType concatType, RankedTensorType concatType,
@@ -853,7 +875,7 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
SmallVector<int64_t, 16> sourceCoreIds = widenToI64(run.sourceCoreIds); SmallVector<int64_t, 16> sourceCoreIds = widenToI64(run.sourceCoreIds);
SmallVector<int64_t, 16> targetCoreIds = widenToI64(run.targetCoreIds); SmallVector<int64_t, 16> targetCoreIds = widenToI64(run.targetCoreIds);
run.packed = createReceiveConcatLoop(state, auto packed = createReceiveConcatLoop(state,
targetClass.op, targetClass.op,
targetClass.body->getTerminator(), targetClass.body->getTerminator(),
*fullPackedType, *fullPackedType,
@@ -862,6 +884,9 @@ FailureOr<Value> materializePackedScalarRunValue(MaterializerState& state,
sourceCoreIds, sourceCoreIds,
targetCoreIds, targetCoreIds,
loc); loc);
if (failed(packed))
return failure();
run.packed = *packed;
return run.packed; return run.packed;
} }
@@ -1559,7 +1584,7 @@ void appendScalarSend(MaterializerState& state,
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
} }
void appendScalarSendLoop(MaterializerState& state, LogicalResult appendScalarSendLoop(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& sourceClass,
Value payload, Value payload,
ArrayRef<int64_t> channelIds, ArrayRef<int64_t> channelIds,
@@ -1578,20 +1603,26 @@ void appendScalarSendLoop(MaterializerState& state,
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size())); getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto sendLoop = buildNormalizedScfFor(
state.rewriter,
OpBuilder::InsertionGuard guard(state.rewriter); loc,
state.rewriter.setInsertionPointToStart(loop.getBody()); lowerBound,
upperBound,
Value index = loop.getInductionVar(); step,
ValueRange {},
[&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl<Value>&) {
Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
return success();
});
if (failed(sendLoop))
return failure();
return success();
} }
Value buildProjectedPackedPayload(MaterializerState& state, FailureOr<Value> buildProjectedPackedPayload(MaterializerState& state,
Operation* anchor, Operation* anchor,
Value fullPayload, Value fullPayload,
const ProjectedTransferDescriptor& descriptor, const ProjectedTransferDescriptor& descriptor,
@@ -1607,19 +1638,15 @@ Value buildProjectedPackedPayload(MaterializerState& state,
Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane);
Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {init},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value fragmentIndex = loop.getInductionVar();
Value acc = body->getArgument(1);
Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane);
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
@@ -1628,14 +1655,18 @@ Value buildProjectedPackedPayload(MaterializerState& state,
Value fragment = createSingleDimExtractSlice( Value fragment = createSingleDimExtractSlice(
state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape()); state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape());
Value packedOffset = scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc); Value packedOffset =
scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc);
Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset);
scf::YieldOp::create(state.rewriter, loc, next); yielded.push_back(next);
return success();
return loop.getResult(0); });
if (failed(loop))
return failure();
return loop->results.front();
} }
void appendProjectedScalarSendLoop(MaterializerState& state, LogicalResult appendProjectedScalarSendLoop(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& sourceClass,
Value payload, Value payload,
const ProjectedTransferDescriptor& descriptor, const ProjectedTransferDescriptor& descriptor,
@@ -1664,11 +1695,14 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
} }
else { else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc);
if (failed(packedPayload))
return failure();
sendPayload = *packedPayload;
} }
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
return; return success();
} }
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
@@ -1676,12 +1710,14 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size())); getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto projectedSendLoop = buildNormalizedScfFor(
state.rewriter,
OpBuilder::InsertionGuard guard(state.rewriter); loc,
state.rewriter.setInsertionPointToStart(loop.getBody()); lowerBound,
upperBound,
Value index = loop.getInductionVar(); step,
ValueRange {},
[&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl<Value>&) {
Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);
@@ -1693,13 +1729,21 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape());
} }
else { else {
sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc);
if (failed(packedPayload))
return failure();
sendPayload = *packedPayload;
} }
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload);
return success();
});
if (failed(projectedSendLoop))
return failure();
return success();
} }
void appendSend(MaterializerState& state, LogicalResult appendSend(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& sourceClass,
Value payload, Value payload,
ArrayRef<int64_t> channelIds, ArrayRef<int64_t> channelIds,
@@ -1717,16 +1761,16 @@ void appendSend(MaterializerState& state,
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
return; return success();
} }
if (channelIds.size() == 1) { if (channelIds.size() == 1) {
appendScalarSend( appendScalarSend(
state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
return; return success();
} }
appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); return appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
} }
Value appendScalarReceive(MaterializerState& state, Value appendScalarReceive(MaterializerState& state,
@@ -1854,7 +1898,7 @@ SmallVector<ClassId, 4> collectDestinationClassesForKeys(MaterializerState& stat
return destinations; return destinations;
} }
SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState& state, FailureOr<SmallVector<ScalarSourceReceivePlan, 4>> emitScalarSourceSends(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys, ArrayRef<ProducerKey> keys,
ArrayRef<ClassId> destinationClasses, ArrayRef<ClassId> destinationClasses,
@@ -1862,7 +1906,9 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
Location loc) { Location loc) {
assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class");
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front()); auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id");
if (failed(sourceCpu))
return failure();
SmallVector<ScalarSourceReceivePlan, 4> receivePlans; SmallVector<ScalarSourceReceivePlan, 4> receivePlans;
receivePlans.reserve(destinationClasses.size()); receivePlans.reserve(destinationClasses.size());
@@ -1870,7 +1916,7 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
const auto tryEmitProjected = [&](ClassId destinationClass, const auto tryEmitProjected = [&](ClassId destinationClass,
const SmallVector<int64_t, 8>& channelIds, const SmallVector<int64_t, 8>& channelIds,
const SmallVector<int32_t, 8>& sourceCoreIds, const SmallVector<int32_t, 8>& sourceCoreIds,
const SmallVector<int32_t, 8>& targetCoreIds) -> bool { const SmallVector<int32_t, 8>& targetCoreIds) -> FailureOr<bool> {
if (keys.size() != 1) if (keys.size() != 1)
return false; return false;
@@ -1891,8 +1937,9 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
!= targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane)) != targetClass.cpus.size() * static_cast<size_t>(descriptor.fragmentsPerLane))
return false; return false;
appendProjectedScalarSendLoop( if (failed(appendProjectedScalarSendLoop(
state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc); state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc)))
return failure();
Value received = Value received =
appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc);
@@ -1911,24 +1958,36 @@ SmallVector<ScalarSourceReceivePlan, 4> emitScalarSourceSends(MaterializerState&
ScalarSourceReceivePlan plan; ScalarSourceReceivePlan plan;
plan.targetClass = destinationClass; plan.targetClass = destinationClass;
auto appendMessage = [&](int32_t targetCpu) { auto appendMessage = [&](CpuId targetCpu) -> LogicalResult {
auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id");
if (failed(checkedTargetCpu))
return failure();
int64_t channelId = state.nextChannelId++; int64_t channelId = state.nextChannelId++;
plan.channelIds.push_back(channelId); plan.channelIds.push_back(channelId);
plan.sourceCoreIds.push_back(sourceCpu); plan.sourceCoreIds.push_back(*sourceCpu);
plan.targetCoreIds.push_back(targetCpu); plan.targetCoreIds.push_back(*checkedTargetCpu);
return success();
}; };
if (!targetClass.isBatch) if (!targetClass.isBatch) {
appendMessage(static_cast<int32_t>(targetClass.cpus.front())); if (failed(appendMessage(targetClass.cpus.front())))
else return failure();
}
else {
for (CpuId targetCpu : targetClass.cpus) for (CpuId targetCpu : targetClass.cpus)
appendMessage(static_cast<int32_t>(targetCpu)); if (failed(appendMessage(targetCpu)))
return failure();
}
if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds)) auto emittedProjected = tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds);
if (failed(emittedProjected))
return failure();
if (*emittedProjected)
continue; continue;
appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); if (failed(appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc)))
return failure();
receivePlans.push_back(std::move(plan)); receivePlans.push_back(std::move(plan));
} }
@@ -1943,10 +2002,11 @@ LogicalResult emitScalarSourceCommunication(
state.availableValues.record(key, sourceClass.id, payload); state.availableValues.record(key, sourceClass.id, payload);
SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys); SmallVector<ClassId, 4> destinationClasses = collectDestinationClassesForKeys(state, keys);
SmallVector<ScalarSourceReceivePlan, 4> receivePlans = auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc);
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc); if (failed(receivePlans))
return failure();
for (const ScalarSourceReceivePlan& plan : receivePlans) { for (const ScalarSourceReceivePlan& plan : *receivePlans) {
MaterializedClass& targetClass = state.classes[plan.targetClass]; MaterializedClass& targetClass = state.classes[plan.targetClass];
Value received = appendReceive( Value received = appendReceive(
@@ -1987,14 +2047,20 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
sourceCoreIds.reserve(sourceClass.cpus.size()); sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size()); targetCoreIds.reserve(sourceClass.cpus.size());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front()); auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id");
if (failed(targetCpu))
return failure();
for (CpuId sourceCpu : sourceClass.cpus) { for (CpuId sourceCpu : sourceClass.cpus) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id");
if (failed(checkedSourceCpu))
return failure();
channelIds.push_back(state.nextChannelId++); channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu)); sourceCoreIds.push_back(*checkedSourceCpu);
targetCoreIds.push_back(targetCpu); targetCoreIds.push_back(*targetCpu);
} }
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc)))
return failure();
return registerLazyPackedScalarReceives( return registerLazyPackedScalarReceives(
state, sourceClass, targetClass, keys, payload.getType(), channelIds, sourceCoreIds, targetCoreIds); state, sourceClass, targetClass, keys, payload.getType(), channelIds, sourceCoreIds, targetCoreIds);
} }
@@ -2011,12 +2077,19 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
targetCoreIds.reserve(targetClass.cpus.size()); targetCoreIds.reserve(targetClass.cpus.size());
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id");
if (failed(checkedSourceCpu))
return failure();
auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id");
if (failed(checkedTargetCpu))
return failure();
channelIds.push_back(state.nextChannelId++); channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu)); sourceCoreIds.push_back(*checkedSourceCpu);
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane])); targetCoreIds.push_back(*checkedTargetCpu);
} }
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc)))
return failure();
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys) for (ProducerKey key : keys)
@@ -2230,35 +2303,30 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(keys.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(keys.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {init},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value loopIndex = loop.getInductionVar();
Value acc = body->getArgument(1);
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced = FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
if (failed(produced)) if (failed(produced) || produced->size() != 1)
return failure();
if (produced->size() != 1)
return failure(); return failure();
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc); Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc);
Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset); Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset);
yielded.push_back(next);
scf::YieldOp::create(state.rewriter, loc, next); return success();
});
run.packed = loop.getResult(0); if (failed(loop))
return failure();
run.packed = loop->results.front();
return run.packed; return run.packed;
} }
@@ -2297,34 +2365,30 @@ FailureOr<Value> insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {destination},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value loopIndex = loop.getInductionVar();
Value acc = body->getArgument(1);
Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced = FailureOr<SmallVector<Value, 4>> produced =
cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex);
if (failed(produced)) if (failed(produced) || produced->size() != 1)
return failure();
if (produced->size() != 1)
return failure(); return failure();
Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc); Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc);
Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc); Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc);
yielded.push_back(next);
scf::YieldOp::create(state.rewriter, loc, next); return success();
return loop.getResult(0); });
if (failed(loop))
return failure();
return loop->results.front();
} }
FailureOr<Value> insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, FailureOr<Value> insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state,
@@ -2362,19 +2426,15 @@ FailureOr<Value> insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState&
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {destination},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value index = loop.getInductionVar();
Value acc = body->getArgument(1);
Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc); Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc); Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc); Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc);
@@ -2382,12 +2442,14 @@ FailureOr<Value> insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState&
Value received = Value received =
SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId)
.getOutput(); .getOutput();
Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc); Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc);
Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc); Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc);
yielded.push_back(next);
scf::YieldOp::create(state.rewriter, loc, next); return success();
return loop.getResult(0); });
if (failed(loop))
return failure();
return loop->results.front();
} }
FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state, FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state,
@@ -2444,25 +2506,24 @@ FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state,
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {destination},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value slotIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value slotIndex = loop.getInductionVar();
Value acc = body->getArgument(1);
Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc); Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc);
Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc); Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc);
Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc); Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc);
yielded.push_back(next);
scf::YieldOp::create(state.rewriter, loc, next); return success();
return loop.getResult(0); });
if (failed(loop))
return failure();
return loop->results.front();
} }
LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state,
@@ -3055,10 +3116,11 @@ LogicalResult emitPackedRunFanout(MaterializerState& state,
Location loc) { Location loc) {
assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class");
SmallVector<ScalarSourceReceivePlan, 4> receivePlans = auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc);
emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc); if (failed(receivePlans))
return failure();
for (const ScalarSourceReceivePlan& plan : receivePlans) { for (const ScalarSourceReceivePlan& plan : *receivePlans) {
MaterializedClass& targetClass = state.classes[plan.targetClass]; MaterializedClass& targetClass = state.classes[plan.targetClass];
Value received = Value received =
@@ -3190,39 +3252,36 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues)); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange(initValues),
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body);
Value loopIndex = loop.getInductionVar();
Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc);
FailureOr<SmallVector<Value, 4>> produced = FailureOr<SmallVector<Value, 4>> produced = cloneBatchBodyForLane(
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex); state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex);
if (failed(produced)) if (failed(produced))
return failure(); return failure();
SmallVector<Value, 4> yielded;
yielded.reserve(produced->size()); yielded.reserve(produced->size());
for (auto [outputIndex, output] : llvm::enumerate(*produced)) { for (auto [outputIndex, output] : llvm::enumerate(*produced)) {
auto fragmentType = cast<RankedTensorType>(output.getType()); auto fragmentType = cast<RankedTensorType>(output.getType());
Value acc = body->getArgument(1 + outputIndex); Value acc = iterArgs[outputIndex];
Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc);
yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset));
} }
return success();
scf::YieldOp::create(state.rewriter, loc, yielded); });
if (failed(loop))
return failure();
SmallVector<Value, 4> results; SmallVector<Value, 4> results;
results.reserve(loop.getNumResults()); results.reserve(loop->results.size());
for (Value result : loop.getResults()) for (Value result : loop->results)
results.push_back(result); results.push_back(result);
return results; return results;
} }
@@ -3523,12 +3582,18 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
for ([[maybe_unused]] const MaterializationRunSlot& slot : run) { for ([[maybe_unused]] const MaterializationRunSlot& slot : run) {
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
if (failed(checkedSourceCpu))
return failure();
auto checkedTargetCpu =
getCheckedCoreId(targetClass.op,
targetClass.isBatch ? targetClass.cpus[lane] : targetClass.cpus.front(),
"batch run target core id");
if (failed(checkedTargetCpu))
return failure();
plan.channelIds.push_back(state.nextChannelId++); plan.channelIds.push_back(state.nextChannelId++);
plan.sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu)); plan.sourceCoreIds.push_back(*checkedSourceCpu);
plan.targetCoreIds.push_back(*checkedTargetCpu);
int32_t targetCpu = targetClass.isBatch ? static_cast<int32_t>(targetClass.cpus[lane])
: static_cast<int32_t>(targetClass.cpus.front());
plan.targetCoreIds.push_back(targetCpu);
} }
} }
@@ -3666,24 +3731,26 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = buildNormalizedScfFor(
state.rewriter,
OpBuilder::InsertionGuard guard(state.rewriter); loc,
state.rewriter.setInsertionPointToStart(loop.getBody()); lowerBound,
upperBound,
Value slotIndex = loop.getInductionVar(); step,
ValueRange {},
[&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl<Value>&) {
Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc);
Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc);
FailureOr<SmallVector<Value, 4>> produced = FailureOr<SmallVector<Value, 4>> produced = cloneBatchBodyForLane(
cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex); state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex);
if (failed(produced)) if (failed(produced))
return failure(); return failure();
for (const BatchRunSendPlan& plan : sendPlans) { for (const BatchRunSendPlan& plan : sendPlans) {
auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); auto resultIt = llvm::find(group.resultIndices, plan.resultIndex);
if (resultIt == group.resultIndices.end()) if (resultIt == group.resultIndices.end())
return targetClass.op->emitError("internal error: missing compacted batch run result"); return failure();
size_t groupOutputIndex = static_cast<size_t>(std::distance(group.resultIndices.begin(), resultIt)); size_t groupOutputIndex = static_cast<size_t>(std::distance(group.resultIndices.begin(), resultIt));
appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc);
@@ -3691,11 +3758,15 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
for (const BatchRunSendPlan& plan : sendPlans) { for (const BatchRunSendPlan& plan : sendPlans) {
if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex])
return sourceBatch.emitOpError("failed to recover per-lane output type for compacted batch run"); return failure();
if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc)))
return failure(); return failure();
} }
return success();
});
if (failed(loop))
return failure();
} }
return success(); return success();
@@ -3754,7 +3825,7 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success(); return success();
} }
Value createReceiveConcatLoop(MaterializerState& state, FailureOr<Value> createReceiveConcatLoop(MaterializerState& state,
Operation* anchor, Operation* anchor,
Operation* insertionPoint, Operation* insertionPoint,
RankedTensorType concatType, RankedTensorType concatType,
@@ -3774,31 +3845,30 @@ Value createReceiveConcatLoop(MaterializerState& state,
state.rewriter.setInsertionPoint(insertionPoint); state.rewriter.setInsertionPoint(insertionPoint);
Value init = Value init =
tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult();
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); auto loop = buildNormalizedScfFor(
state.rewriter,
Block* body = loop.getBody(); loc,
if (!body->empty()) lowerBound,
if (auto yield = dyn_cast<scf::YieldOp>(body->back())) upperBound,
state.rewriter.eraseOp(yield); step,
ValueRange {init},
OpBuilder::InsertionGuard guard(state.rewriter); [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
state.rewriter.setInsertionPointToEnd(body); Value acc = iterArgs.front();
Value index = loop.getInductionVar();
Value acc = body->getArgument(1);
Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc);
Value received = Value received =
SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput(); SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId)
.getOutput();
Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc);
Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset); Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset);
scf::YieldOp::create(state.rewriter, loc, next); yielded.push_back(next);
return success();
return loop.getResult(0); });
if (failed(loop))
return failure();
return loop->results.front();
} }
void replaceHostUses(MaterializerState& state) { void replaceHostUses(MaterializerState& state) {
@@ -3832,7 +3902,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
if (failed(collectHostOutputs(state))) if (failed(collectHostOutputs(state)))
return failure(); return failure();
createEmptyMaterializedOps(state); if (failed(createEmptyMaterializedOps(state)))
return failure();
if (failed(collectProducerDestinations(state))) if (failed(collectProducerDestinations(state)))
return failure(); return failure();
if (failed(collectProjectedTransfers(state))) if (failed(collectProjectedTransfers(state)))
@@ -37,6 +37,7 @@
#include "Scheduling/MergeSchedulingAnalysis.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -128,8 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
} }
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
return static_cast<int32_t>(coreIdAttr.getInt()); auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
if (failed(checkedCoreId))
return std::nullopt;
return *checkedCoreId;
}
return std::nullopt; return std::nullopt;
} }
-6
View File
@@ -1,11 +1,5 @@
add_pim_library(OMPimPasses add_pim_library(OMPimPasses
MessagePass.cpp MessagePass.cpp
PimCodegen/HostConstantFolding/Common.cpp
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
PimCodegen/MaterializeHostConstantsPass.cpp
PimCodegen/VerificationPass.cpp
PimCodegen/EmitPimCodePass.cpp PimCodegen/EmitPimCodePass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
+1 -1
View File
@@ -15,7 +15,7 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass(); std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createPimStaticMemoryCoalescingPass(); std::unique_ptr<mlir::Pass> createPimMemoryCoalescingPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass(); std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
+1 -1
View File
@@ -75,7 +75,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass); registerPass(createPimBufferizationPass);
registerPass(createPimStaticMemoryCoalescingPass); registerPass(createPimMemoryCoalescingPass);
registerPass(createMergeComputeNodesPass); registerPass(createMergeComputeNodesPass);
registerPass(createPimHostConstantFoldingPass); registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeHostConstantsPass); registerPass(createPimMaterializeHostConstantsPass);