better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
This commit is contained in:
@@ -264,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|
||||||
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
|||||||
return numElements;
|
return numElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasByteSizedElementType(mlir::Type elementType) {
|
||||||
|
if (mlir::isa<mlir::IndexType>(elementType))
|
||||||
|
return true;
|
||||||
|
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||||
|
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
|
||||||
|
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||||
|
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||||
|
if (mlir::isa<mlir::IndexType>(elementType))
|
||||||
|
return mlir::IndexType::kInternalStorageBitWidth / 8;
|
||||||
|
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||||
|
return static_cast<size_t>(intType.getWidth() / 8);
|
||||||
|
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||||
|
return static_cast<size_t>(floatType.getWidth() / 8);
|
||||||
|
llvm_unreachable("expected byte-sized integer, float, or index element type");
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
|
||||||
|
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
|
|||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool hasByteSizedElementType(mlir::Type elementType);
|
||||||
|
|
||||||
|
size_t getElementTypeSizeInBytes(mlir::Type elementType);
|
||||||
|
|
||||||
|
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
|||||||
@@ -41,23 +41,10 @@ using namespace mlir;
|
|||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
|
|
||||||
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
|
||||||
if (elementType.isIndex())
|
|
||||||
return sizeof(int64_t);
|
|
||||||
if (elementType.isIntOrFloat())
|
|
||||||
return elementType.getIntOrFloatBitWidth() / 8;
|
|
||||||
llvm_unreachable("unsupported shaped element type");
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
|
||||||
auto type = cast<ShapedType>(value.getType());
|
|
||||||
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||||
auto type = cast<ShapedType>(value.getType());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
assert("Only static shape is supported" && type.hasStaticShape());
|
assert("Only static shape is supported" && type.hasStaticShape());
|
||||||
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||||
MemEntry memEntry = {0, allocSize};
|
MemEntry memEntry = {0, allocSize};
|
||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
@@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
|||||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||||
const StaticValueKnowledge& knowledge) const {
|
const StaticValueKnowledge& knowledge) const {
|
||||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||||
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
||||||
|
/ receiveTensorOp.getSourceCoreIds().size();
|
||||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||||
}
|
}
|
||||||
@@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
|
|||||||
|
|
||||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||||
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
||||||
|
/ sendTensorOp.getTargetCoreIds().size();
|
||||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||||
}
|
}
|
||||||
@@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
|
|||||||
|
|
||||||
int64_t axis = concatOp.getAxis();
|
int64_t axis = concatOp.getAxis();
|
||||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
|
||||||
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
||||||
|
|
||||||
size_t outerCount = 1;
|
size_t outerCount = 1;
|
||||||
@@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
|||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 1;
|
instruction.r2OrImm = 1;
|
||||||
instruction.generic1 = 1;
|
instruction.generic1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
|||||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput()));
|
instruction.generic3 =
|
||||||
|
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
|||||||
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||||
auto srcShape = srcType.getShape();
|
auto srcShape = srcType.getShape();
|
||||||
size_t rank = srcShape.size();
|
size_t rank = srcShape.size();
|
||||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
|
||||||
size_t totalElements = srcType.getNumElements();
|
size_t totalElements = srcType.getNumElements();
|
||||||
|
|
||||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
int64_t numCols = shape[1];
|
int64_t numCols = shape[1];
|
||||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||||
|
|
||||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
|
|||||||
@@ -427,11 +427,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||||
if (!weightArg || !inputArg)
|
if (!weightArg || !inputArg)
|
||||||
return failure();
|
return failure();
|
||||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
|
||||||
gemmLoc,
|
|
||||||
currOutHSliceType,
|
|
||||||
*weightArg,
|
|
||||||
*inputArg));
|
|
||||||
}
|
}
|
||||||
if (vmmOutputs.empty()) {
|
if (vmmOutputs.empty()) {
|
||||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
tensor::ParallelInsertSliceOp insertSlice,
|
tensor::ParallelInsertSliceOp insertSlice,
|
||||||
ShapedType destinationType,
|
ShapedType destinationType,
|
||||||
IRMapping& mapper) {
|
IRMapping& mapper) {
|
||||||
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
||||||
ArrayRef<int64_t> shape = destinationType.getShape();
|
ArrayRef<int64_t> shape = destinationType.getShape();
|
||||||
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
||||||
|
|||||||
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
|||||||
return returnValue;
|
return returnValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
|
|
||||||
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ namespace onnx_mlir {
|
|||||||
*/
|
*/
|
||||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||||
|
|
||||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
|
||||||
|
|
||||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
markOpToRemove(op);
|
markOpToRemove(op);
|
||||||
|
|
||||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||||
rewriter.setInsertionPointAfter(storedOp);
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||||
@@ -455,7 +455,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
|
|
||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||||
rewriter.setInsertionPointAfterValue(storedValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
@@ -471,7 +471,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||||
markOpToRemove(concatOp);
|
markOpToRemove(concatOp);
|
||||||
|
|
||||||
|
|||||||
@@ -325,9 +325,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
|||||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||||
Type elementType = tensorType.getElementType();
|
Type elementType = tensorType.getElementType();
|
||||||
if (!elementType.isIntOrFloat())
|
if (!hasByteSizedElementType(elementType))
|
||||||
return;
|
return;
|
||||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
|
||||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||||
|
|
||||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
|||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||||
|
|
||||||
return PimMemCopyOp::create(rewriter,
|
return PimMemCopyOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||||
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
|
||||||
return builder.getI32IntegerAttr(sizeInBytes);
|
return builder.getI32IntegerAttr(sizeInBytes);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool isCandidateAllocType(MemRefType type) {
|
static bool isCandidateAllocType(MemRefType type) {
|
||||||
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
|
return type && type.hasStaticShape() && type.getLayout().isIdentity()
|
||||||
|
&& hasByteSizedElementType(type.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint64_t getTypeSizeBytes(MemRefType type) {
|
static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<uint64_t>
|
static FailureOr<uint64_t>
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); }
|
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), idx);
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||||
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
||||||
@@ -74,11 +76,13 @@ SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Locat
|
|||||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
||||||
newCompute->setAttrs((*this)->getAttrs());
|
newCompute->setAttrs((*this)->getAttrs());
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||||
newCompute.getOperation(), static_cast<int32_t>(newCompute.getWeights().size()), static_cast<int32_t>(newCompute.getInputs().size()));
|
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||||
getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
rewriter.eraseOp(getOperation());
|
rewriter.eraseOp(getOperation());
|
||||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||||
}
|
}
|
||||||
@@ -110,7 +114,8 @@ std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
|||||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = getInputs().size();
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
@@ -145,8 +150,9 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
|
|||||||
auto newBatch =
|
auto newBatch =
|
||||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
||||||
newBatch->setAttrs((*this)->getAttrs());
|
newBatch->setAttrs((*this)->getAttrs());
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||||
newBatch.getOperation(), static_cast<int32_t>(newBatch.getWeights().size()), static_cast<int32_t>(newBatch.getInputs().size()));
|
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||||
if (newBatch.getBody().empty()) {
|
if (newBatch.getBody().empty()) {
|
||||||
rewriter.eraseOp(newBatch);
|
rewriter.eraseOp(newBatch);
|
||||||
@@ -155,7 +161,8 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
|
|||||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||||
getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
rewriter.eraseOp(getOperation());
|
rewriter.eraseOp(getOperation());
|
||||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||||
}
|
}
|
||||||
|
|||||||
+305
-228
@@ -4,7 +4,6 @@
|
|||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
@@ -131,7 +130,7 @@ struct MaterializerState {
|
|||||||
DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance;
|
DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance;
|
||||||
DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots;
|
DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots;
|
||||||
|
|
||||||
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> producerDestClasses;
|
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
||||||
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues;
|
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues;
|
||||||
DenseMap<Value, Value> hostReplacements;
|
DenseMap<Value, Value> hostReplacements;
|
||||||
DenseSet<Operation*> oldComputeOps;
|
DenseSet<Operation*> oldComputeOps;
|
||||||
@@ -574,6 +573,77 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
|
|||||||
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||||
|
SmallVector<APInt, 8> elements;
|
||||||
|
elements.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
elements.push_back(APInt(64, value));
|
||||||
|
|
||||||
|
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
|
||||||
|
auto attr = DenseIntElementsAttr::get(type, elements);
|
||||||
|
return getOrCreateHostConstant(anchor, attr, type, state.constantFolder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
|
||||||
|
SmallVector<int64_t, 8> widened;
|
||||||
|
widened.reserve(values.size());
|
||||||
|
for (int32_t value : values)
|
||||||
|
widened.push_back(value);
|
||||||
|
return createIndexTensorConstant(state, anchor, ArrayRef<int64_t>(widened));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allEqual(ArrayRef<int64_t> values) {
|
||||||
|
assert(!values.empty() && "expected at least one value");
|
||||||
|
for (int64_t value : values.drop_front())
|
||||||
|
if (value != values.front())
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allEqual(ArrayRef<int32_t> values) {
|
||||||
|
assert(!values.empty() && "expected at least one value");
|
||||||
|
for (int32_t value : values.drop_front())
|
||||||
|
if (value != values.front())
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value createLaneIndexedIndexValue(MaterializerState& state,
|
||||||
|
MaterializedClass& materializedClass,
|
||||||
|
ArrayRef<int64_t> values,
|
||||||
|
Location loc) {
|
||||||
|
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
|
||||||
|
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
|
||||||
|
|
||||||
|
if (allEqual(values))
|
||||||
|
return createIndexConstant(state, materializedClass.op, values.front());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(materializedClass.op);
|
||||||
|
auto laneArg = batch.getLaneArgument();
|
||||||
|
assert(laneArg && "expected compute_batch lane argument");
|
||||||
|
|
||||||
|
Value table = createIndexTensorConstant(state, materializedClass.op, values);
|
||||||
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value createLaneIndexedIndexValue(MaterializerState& state,
|
||||||
|
MaterializedClass& materializedClass,
|
||||||
|
ArrayRef<int32_t> values,
|
||||||
|
Location loc) {
|
||||||
|
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
|
||||||
|
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
|
||||||
|
|
||||||
|
if (allEqual(values))
|
||||||
|
return createIndexConstant(state, materializedClass.op, values.front());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(materializedClass.op);
|
||||||
|
auto laneArg = batch.getLaneArgument();
|
||||||
|
assert(laneArg && "expected compute_batch lane argument");
|
||||||
|
|
||||||
|
Value table = createIndexTensorConstant(state, materializedClass.op, values);
|
||||||
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
FailureOr<SmallVector<ComputeInstance, 8>>
|
FailureOr<SmallVector<ComputeInstance, 8>>
|
||||||
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
||||||
SmallVector<ComputeInstance, 8> peers;
|
SmallVector<ComputeInstance, 8> peers;
|
||||||
@@ -623,14 +693,12 @@ Value createOriginalLaneValue(MaterializerState& state,
|
|||||||
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
|
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<APInt, 8> laneValues;
|
SmallVector<int64_t, 8> laneValues;
|
||||||
laneValues.reserve(peers.size());
|
laneValues.reserve(peers.size());
|
||||||
for (const ComputeInstance& peer : peers)
|
for (const ComputeInstance& peer : peers)
|
||||||
laneValues.push_back(APInt(64, peer.laneStart));
|
laneValues.push_back(peer.laneStart);
|
||||||
|
|
||||||
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType());
|
Value table = createIndexTensorConstant(state, materializedClass.op, laneValues);
|
||||||
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
|
|
||||||
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
|
|
||||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -659,6 +727,12 @@ bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) {
|
||||||
|
SmallVector<ClassId, 4>& destinations = state.producerDestClasses[key];
|
||||||
|
if (!llvm::is_contained(destinations, classId))
|
||||||
|
destinations.push_back(classId);
|
||||||
|
}
|
||||||
|
|
||||||
void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) {
|
void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) {
|
||||||
SmallVector<OpOperand*> uses;
|
SmallVector<OpOperand*> uses;
|
||||||
for (OpOperand& use : oldValue.getUses())
|
for (OpOperand& use : oldValue.getUses())
|
||||||
@@ -693,7 +767,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|||||||
if (sourceClass == targetClass)
|
if (sourceClass == targetClass)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
state.producerDestClasses[producerKey].insert(targetClass);
|
appendDestinationClass(state, producerKey, targetClass);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -714,29 +788,70 @@ bool haveSameDestinationClasses(MaterializerState& state, ArrayRef<ProducerKey>
|
|||||||
return true;
|
return true;
|
||||||
|
|
||||||
auto firstIt = state.producerDestClasses.find(keys.front());
|
auto firstIt = state.producerDestClasses.find(keys.front());
|
||||||
DenseSet<ClassId> empty;
|
ArrayRef<ClassId> first = firstIt == state.producerDestClasses.end() ? ArrayRef<ClassId>() : firstIt->second;
|
||||||
const DenseSet<ClassId>& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second;
|
|
||||||
for (ProducerKey key : keys.drop_front()) {
|
for (ProducerKey key : keys.drop_front()) {
|
||||||
auto it = state.producerDestClasses.find(key);
|
auto it = state.producerDestClasses.find(key);
|
||||||
const DenseSet<ClassId>& current = it == state.producerDestClasses.end() ? empty : it->second;
|
ArrayRef<ClassId> current = it == state.producerDestClasses.end() ? ArrayRef<ClassId>() : it->second;
|
||||||
if (first.size() != current.size())
|
if (first.size() != current.size())
|
||||||
return false;
|
return false;
|
||||||
for (ClassId classId : first)
|
for (auto [lhs, rhs] : llvm::zip(first, current))
|
||||||
if (!current.contains(classId))
|
if (lhs != rhs)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, ProducerKey key) {
|
ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey key) {
|
||||||
SmallVector<ClassId, 4> destinations;
|
|
||||||
auto it = state.producerDestClasses.find(key);
|
auto it = state.producerDestClasses.find(key);
|
||||||
if (it == state.producerDestClasses.end())
|
if (it == state.producerDestClasses.end())
|
||||||
return destinations;
|
return {};
|
||||||
for (ClassId classId : it->second)
|
return it->second;
|
||||||
destinations.push_back(classId);
|
}
|
||||||
llvm::sort(destinations);
|
|
||||||
return destinations;
|
void appendSend(MaterializerState& state,
|
||||||
|
MaterializedClass& sourceClass,
|
||||||
|
Value payload,
|
||||||
|
ArrayRef<int64_t> channelIds,
|
||||||
|
ArrayRef<int32_t> sourceCoreIds,
|
||||||
|
ArrayRef<int32_t> targetCoreIds,
|
||||||
|
Location loc) {
|
||||||
|
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
||||||
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||||
|
assert(!channelIds.empty() && "expected at least one send");
|
||||||
|
|
||||||
|
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||||
|
|
||||||
|
if (sourceClass.isBatch) {
|
||||||
|
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
|
||||||
|
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
|
||||||
|
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
|
||||||
|
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
|
||||||
|
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
|
||||||
|
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
|
||||||
|
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
|
||||||
|
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value appendScalarReceive(MaterializerState& state,
|
||||||
|
MaterializedClass& targetClass,
|
||||||
|
Type type,
|
||||||
|
int64_t channelId,
|
||||||
|
int32_t sourceCoreId,
|
||||||
|
int32_t targetCoreId,
|
||||||
|
Location loc) {
|
||||||
|
assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class");
|
||||||
|
|
||||||
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
|
Value channelIdValue = createIndexConstant(state, targetClass.op, channelId);
|
||||||
|
Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId);
|
||||||
|
Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId);
|
||||||
|
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue)
|
||||||
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value appendReceive(MaterializerState& state,
|
Value appendReceive(MaterializerState& state,
|
||||||
@@ -746,50 +861,169 @@ Value appendReceive(MaterializerState& state,
|
|||||||
ArrayRef<int32_t> sourceCoreIds,
|
ArrayRef<int32_t> sourceCoreIds,
|
||||||
ArrayRef<int32_t> targetCoreIds,
|
ArrayRef<int32_t> targetCoreIds,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
|
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
||||||
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||||
|
assert(!channelIds.empty() && "expected at least one receive");
|
||||||
|
|
||||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
|
|
||||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
|
|
||||||
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds);
|
|
||||||
|
|
||||||
if (targetClass.isBatch) {
|
if (targetClass.isBatch) {
|
||||||
return SpatChannelReceiveBatchOp::create(
|
Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc);
|
||||||
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc);
|
||||||
.getOutput();
|
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc);
|
||||||
|
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (channelIds.size() != 1) {
|
assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time");
|
||||||
return SpatChannelReceiveTensorOp::create(
|
return appendScalarReceive(
|
||||||
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
|
||||||
.getOutput();
|
|
||||||
}
|
|
||||||
|
|
||||||
return SpatChannelReceiveOp::create(
|
|
||||||
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
|
||||||
.getOutput();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value appendHostReceive(MaterializerState& state,
|
Value appendPackedScalarReceives(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& targetClass,
|
||||||
Type type,
|
Type fragmentType,
|
||||||
ArrayRef<int64_t> channelIds,
|
Type packedType,
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
ArrayRef<int64_t> channelIds,
|
||||||
ArrayRef<int32_t> targetCoreIds,
|
ArrayRef<int32_t> sourceCoreIds,
|
||||||
Location loc) {
|
ArrayRef<int32_t> targetCoreIds,
|
||||||
state.rewriter.setInsertionPointAfter(sourceClass.op);
|
Location loc) {
|
||||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
|
assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class");
|
||||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
|
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
||||||
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||||
|
assert(!channelIds.empty() && "expected at least one receive");
|
||||||
|
|
||||||
if (sourceClass.isBatch) {
|
SmallVector<Value, 8> fragments;
|
||||||
return SpatChannelReceiveTensorOp::create(
|
fragments.reserve(channelIds.size());
|
||||||
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
|
||||||
.getOutput();
|
fragments.push_back(appendScalarReceive(
|
||||||
|
state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc));
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
return SpatChannelReceiveOp::create(
|
|
||||||
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
Value packed = fragments.front();
|
||||||
.getOutput();
|
if (fragments.size() != 1)
|
||||||
|
packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult();
|
||||||
|
|
||||||
|
if (packed.getType() != packedType)
|
||||||
|
packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult();
|
||||||
|
|
||||||
|
return packed;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||||
|
MaterializedClass& sourceClass,
|
||||||
|
MaterializedClass& targetClass,
|
||||||
|
ArrayRef<ProducerKey> keys,
|
||||||
|
Value payload,
|
||||||
|
Location loc) {
|
||||||
|
if (sourceClass.id == targetClass.id) {
|
||||||
|
for (ProducerKey key : keys)
|
||||||
|
state.availableValues[key][targetClass.id] = payload;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sourceClass.isBatch && !targetClass.isBatch) {
|
||||||
|
int64_t channelId = state.nextChannelId++;
|
||||||
|
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||||
|
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
||||||
|
|
||||||
|
SmallVector<int64_t, 1> channelIds {channelId};
|
||||||
|
SmallVector<int32_t, 1> sourceCoreIds {sourceCpu};
|
||||||
|
SmallVector<int32_t, 1> targetCoreIds {targetCpu};
|
||||||
|
|
||||||
|
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
Value received =
|
||||||
|
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
for (ProducerKey key : keys)
|
||||||
|
state.availableValues[key][targetClass.id] = received;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sourceClass.isBatch && targetClass.isBatch) {
|
||||||
|
SmallVector<int64_t, 8> channelIds;
|
||||||
|
SmallVector<int32_t, 8> sourceCoreIds;
|
||||||
|
SmallVector<int32_t, 8> targetCoreIds;
|
||||||
|
channelIds.reserve(targetClass.cpus.size());
|
||||||
|
sourceCoreIds.reserve(targetClass.cpus.size());
|
||||||
|
targetCoreIds.reserve(targetClass.cpus.size());
|
||||||
|
|
||||||
|
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||||
|
for (CpuId targetCpu : targetClass.cpus) {
|
||||||
|
channelIds.push_back(state.nextChannelId++);
|
||||||
|
sourceCoreIds.push_back(sourceCpu);
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
|
||||||
|
}
|
||||||
|
|
||||||
|
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
Value received =
|
||||||
|
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
for (ProducerKey key : keys)
|
||||||
|
state.availableValues[key][targetClass.id] = received;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sourceClass.isBatch && !targetClass.isBatch) {
|
||||||
|
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
|
||||||
|
if (!packedKey)
|
||||||
|
return sourceClass.op->emitError(
|
||||||
|
"cannot materialize batch-to-scalar communication because source lanes are not contiguous");
|
||||||
|
|
||||||
|
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
|
||||||
|
if (failed(packedType))
|
||||||
|
return sourceClass.op->emitError(
|
||||||
|
"cannot materialize batch-to-scalar communication for non-static ranked tensor payload");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> channelIds;
|
||||||
|
SmallVector<int32_t, 8> sourceCoreIds;
|
||||||
|
SmallVector<int32_t, 8> targetCoreIds;
|
||||||
|
channelIds.reserve(sourceClass.cpus.size());
|
||||||
|
sourceCoreIds.reserve(sourceClass.cpus.size());
|
||||||
|
targetCoreIds.reserve(sourceClass.cpus.size());
|
||||||
|
|
||||||
|
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
||||||
|
for (CpuId sourceCpu : sourceClass.cpus) {
|
||||||
|
channelIds.push_back(state.nextChannelId++);
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
||||||
|
targetCoreIds.push_back(targetCpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
Value received = appendPackedScalarReceives(
|
||||||
|
state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
state.availableValues[*packedKey][targetClass.id] = received;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sourceClass.isBatch && targetClass.isBatch) {
|
||||||
|
if (sourceClass.cpus.size() != targetClass.cpus.size())
|
||||||
|
return sourceClass.op->emitError(
|
||||||
|
"cannot materialize batch communication between equivalence classes of different sizes");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> channelIds;
|
||||||
|
SmallVector<int32_t, 8> sourceCoreIds;
|
||||||
|
SmallVector<int32_t, 8> targetCoreIds;
|
||||||
|
channelIds.reserve(sourceClass.cpus.size());
|
||||||
|
sourceCoreIds.reserve(sourceClass.cpus.size());
|
||||||
|
targetCoreIds.reserve(targetClass.cpus.size());
|
||||||
|
|
||||||
|
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
||||||
|
channelIds.push_back(state.nextChannelId++);
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
|
||||||
|
}
|
||||||
|
|
||||||
|
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
Value received =
|
||||||
|
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
|
|
||||||
|
for (ProducerKey key : keys)
|
||||||
|
state.availableValues[key][targetClass.id] = received;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
@@ -821,207 +1055,50 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
|||||||
if (!payloadType || !payloadType.hasStaticShape())
|
if (!payloadType || !payloadType.hasStaticShape())
|
||||||
return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor");
|
return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor");
|
||||||
|
|
||||||
|
auto laneArg = batch.getLaneArgument();
|
||||||
|
if (!laneArg)
|
||||||
|
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
|
||||||
|
|
||||||
|
auto outputArg = batch.getOutputArgument(resultIndex);
|
||||||
|
if (!outputArg)
|
||||||
|
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
||||||
|
|
||||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
|
||||||
SmallVector<OpFoldResult, 4> offsets;
|
SmallVector<OpFoldResult, 4> offsets;
|
||||||
SmallVector<OpFoldResult, 4> sizes;
|
SmallVector<OpFoldResult, 4> sizes;
|
||||||
SmallVector<OpFoldResult, 4> strides;
|
SmallVector<OpFoldResult, 4> strides;
|
||||||
offsets.reserve(payloadType.getRank());
|
offsets.reserve(payloadType.getRank());
|
||||||
sizes.reserve(payloadType.getRank());
|
sizes.reserve(payloadType.getRank());
|
||||||
strides.reserve(payloadType.getRank());
|
strides.reserve(payloadType.getRank());
|
||||||
auto laneArg = batch.getLaneArgument();
|
|
||||||
if (!laneArg)
|
|
||||||
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
|
|
||||||
offsets.push_back(*laneArg);
|
offsets.push_back(*laneArg);
|
||||||
sizes.push_back(state.rewriter.getIndexAttr(1));
|
sizes.push_back(state.rewriter.getIndexAttr(1));
|
||||||
strides.push_back(state.rewriter.getIndexAttr(1));
|
strides.push_back(state.rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
|
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
|
||||||
offsets.push_back(state.rewriter.getIndexAttr(0));
|
offsets.push_back(state.rewriter.getIndexAttr(0));
|
||||||
sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim)));
|
sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim)));
|
||||||
strides.push_back(state.rewriter.getIndexAttr(1));
|
strides.push_back(state.rewriter.getIndexAttr(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputArg = batch.getOutputArgument(resultIndex);
|
|
||||||
if (!outputArg)
|
|
||||||
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
|
||||||
|
|
||||||
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
|
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void appendScalarSend(MaterializerState& state,
|
|
||||||
MaterializedClass& sourceClass,
|
|
||||||
Value payload,
|
|
||||||
int64_t channelId,
|
|
||||||
int32_t sourceCoreId,
|
|
||||||
int32_t targetCoreId,
|
|
||||||
Location loc) {
|
|
||||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
|
||||||
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
|
|
||||||
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
|
|
||||||
Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId);
|
|
||||||
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
|
|
||||||
}
|
|
||||||
|
|
||||||
void appendBatchSend(MaterializerState& state,
|
|
||||||
MaterializedClass& sourceClass,
|
|
||||||
Value payload,
|
|
||||||
ArrayRef<int64_t> channelIds,
|
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
|
||||||
ArrayRef<int32_t> targetCoreIds,
|
|
||||||
Location loc) {
|
|
||||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
|
||||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
|
|
||||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
|
|
||||||
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
|
|
||||||
SpatChannelSendBatchOp::create(state.rewriter, loc, channelIdValues, sourceCoreIdValues, targetCoreIdValues, payload);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|
||||||
MaterializedClass& sourceClass,
|
|
||||||
MaterializedClass& targetClass,
|
|
||||||
ArrayRef<ProducerKey> keys,
|
|
||||||
Value payload,
|
|
||||||
Location loc) {
|
|
||||||
if (sourceClass.id == targetClass.id) {
|
|
||||||
for (ProducerKey key : keys)
|
|
||||||
state.availableValues[key][targetClass.id] = payload;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sourceClass.isBatch && !targetClass.isBatch) {
|
|
||||||
int64_t channelId = state.nextChannelId++;
|
|
||||||
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
|
||||||
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
|
||||||
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
|
|
||||||
Value received = appendReceive(state,
|
|
||||||
targetClass,
|
|
||||||
payload.getType(),
|
|
||||||
ArrayRef<int64_t>(channelId),
|
|
||||||
ArrayRef<int32_t>(sourceCpu),
|
|
||||||
ArrayRef<int32_t>(targetCpu),
|
|
||||||
loc);
|
|
||||||
for (ProducerKey key : keys)
|
|
||||||
state.availableValues[key][targetClass.id] = received;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!sourceClass.isBatch && targetClass.isBatch) {
|
|
||||||
SmallVector<int64_t, 8> channelIds;
|
|
||||||
SmallVector<int32_t, 8> sourceCoreIds;
|
|
||||||
SmallVector<int32_t, 8> targetCoreIds;
|
|
||||||
channelIds.reserve(targetClass.cpus.size());
|
|
||||||
sourceCoreIds.reserve(targetClass.cpus.size());
|
|
||||||
targetCoreIds.reserve(targetClass.cpus.size());
|
|
||||||
|
|
||||||
for (CpuId targetCpu : targetClass.cpus) {
|
|
||||||
int64_t channelId = state.nextChannelId++;
|
|
||||||
channelIds.push_back(channelId);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceClass.cpus.front()));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
|
|
||||||
appendScalarSend(state,
|
|
||||||
sourceClass,
|
|
||||||
payload,
|
|
||||||
channelId,
|
|
||||||
static_cast<int32_t>(sourceClass.cpus.front()),
|
|
||||||
static_cast<int32_t>(targetCpu),
|
|
||||||
loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value received =
|
|
||||||
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
for (ProducerKey key : keys)
|
|
||||||
state.availableValues[key][targetClass.id] = received;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sourceClass.isBatch && !targetClass.isBatch) {
|
|
||||||
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
|
|
||||||
if (!packedKey)
|
|
||||||
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
|
|
||||||
"lanes are not contiguous in send order");
|
|
||||||
|
|
||||||
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
|
|
||||||
if (failed(packedType))
|
|
||||||
return sourceClass.op->emitError(
|
|
||||||
"cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload");
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> channelIds;
|
|
||||||
SmallVector<int32_t, 8> sourceCoreIds;
|
|
||||||
SmallVector<int32_t, 8> targetCoreIds;
|
|
||||||
channelIds.reserve(sourceClass.cpus.size());
|
|
||||||
sourceCoreIds.reserve(sourceClass.cpus.size());
|
|
||||||
targetCoreIds.reserve(sourceClass.cpus.size());
|
|
||||||
|
|
||||||
for (CpuId sourceCpu : sourceClass.cpus) {
|
|
||||||
channelIds.push_back(state.nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus.front()));
|
|
||||||
}
|
|
||||||
|
|
||||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
state.availableValues[*packedKey][targetClass.id] = received;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sourceClass.isBatch && targetClass.isBatch) {
|
|
||||||
if (sourceClass.cpus.size() != targetClass.cpus.size())
|
|
||||||
return sourceClass.op->emitError(
|
|
||||||
"cannot materialize batch communication between equivalence classes of different sizes");
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> channelIds;
|
|
||||||
SmallVector<int32_t, 8> sourceCoreIds;
|
|
||||||
SmallVector<int32_t, 8> targetCoreIds;
|
|
||||||
channelIds.reserve(sourceClass.cpus.size());
|
|
||||||
sourceCoreIds.reserve(sourceClass.cpus.size());
|
|
||||||
targetCoreIds.reserve(targetClass.cpus.size());
|
|
||||||
|
|
||||||
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
|
|
||||||
channelIds.push_back(state.nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
|
|
||||||
}
|
|
||||||
|
|
||||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
Value received =
|
|
||||||
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
for (ProducerKey key : keys)
|
|
||||||
state.availableValues[key][targetClass.id] = received;
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
return sourceClass.op->emitError("unhandled materialized communication pattern");
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult emitHostCommunication(MaterializerState& state,
|
LogicalResult emitHostCommunication(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& sourceClass,
|
||||||
ArrayRef<ProducerKey> keys,
|
ArrayRef<ProducerKey> keys,
|
||||||
Value payload,
|
Value payload,
|
||||||
Value originalOutput,
|
Value originalOutput,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
|
(void) keys;
|
||||||
|
(void) loc;
|
||||||
|
|
||||||
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
if (!sourceClass.hostOutputs.empty())
|
return setHostOutputValue(state, sourceClass, originalOutput, payload);
|
||||||
return setHostOutputValue(state, sourceClass, originalOutput, payload);
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> channelIds;
|
|
||||||
SmallVector<int32_t, 8> sourceCoreIds;
|
|
||||||
SmallVector<int32_t, 8> targetCoreIds;
|
|
||||||
channelIds.reserve(sourceClass.cpus.size());
|
|
||||||
sourceCoreIds.reserve(sourceClass.cpus.size());
|
|
||||||
targetCoreIds.reserve(sourceClass.cpus.size());
|
|
||||||
for (CpuId sourceCpu : sourceClass.cpus) {
|
|
||||||
channelIds.push_back(state.nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
|
||||||
targetCoreIds.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
Value received =
|
|
||||||
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
||||||
state.hostReplacements[originalOutput] = received;
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult emitOutputFanout(MaterializerState& state,
|
LogicalResult emitOutputFanout(MaterializerState& state,
|
||||||
@@ -1034,7 +1111,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
return success();
|
return success();
|
||||||
|
|
||||||
if (!sourceClass.isBatch) {
|
if (!sourceClass.isBatch) {
|
||||||
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
|
||||||
if (failed(
|
if (failed(
|
||||||
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -1048,7 +1125,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
return sourceClass.op->emitError(
|
return sourceClass.op->emitError(
|
||||||
"cannot materialize batched output whose lanes have different destination equivalence classes");
|
"cannot materialize batched output whose lanes have different destination equivalence classes");
|
||||||
|
|
||||||
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
|
||||||
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|||||||
@@ -660,12 +660,12 @@ public:
|
|||||||
|
|
||||||
emitMergeIrCounts("after-materialization", func);
|
emitMergeIrCounts("after-materialization", func);
|
||||||
|
|
||||||
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
emitMergeIrCounts("after-post-merge-compaction", func);
|
emitMergeIrCounts("after-post-merge-compaction", func);*/
|
||||||
|
|
||||||
{
|
{
|
||||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
|
|
||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
mapOp.getLoc(),
|
mapOp.getLoc(),
|
||||||
initType,
|
initType,
|
||||||
|
|||||||
@@ -176,9 +176,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
if (!hasByteSizedElementType(sourceType.getElementType()))
|
||||||
if (elementByteWidth <= 0)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
|
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
if (size != totalBytes)
|
if (size != totalBytes)
|
||||||
|
|||||||
@@ -31,13 +31,6 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int64_t getValueSizeInBytes(Value value) {
|
|
||||||
auto type = dyn_cast<ShapedType>(value.getType());
|
|
||||||
if (!type || !type.hasStaticShape())
|
|
||||||
return -1;
|
|
||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CoreOpTy>
|
template <typename CoreOpTy>
|
||||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||||
IRRewriter& rewriter,
|
IRRewriter& rewriter,
|
||||||
@@ -82,7 +75,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
int64_t totalBytes = -1;
|
||||||
|
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape())
|
||||||
|
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
|
||||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||||
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
|
|||||||
Reference in New Issue
Block a user