fix much stuff

This commit is contained in:
NiccoloN
2026-05-22 18:53:38 +02:00
parent 8337a11ce9
commit 2c1da813b5
18 changed files with 502 additions and 191 deletions
+18 -6
View File
@@ -43,8 +43,14 @@ def SpatCompute : SpatOp<"compute",
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
@@ -70,10 +76,16 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
::mlir::BlockArgument getOutputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
+147 -17
View File
@@ -6,11 +6,81 @@ using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
if (body.empty())
return std::nullopt;
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
return getBody().front().getArgument(getWeights().size() + idx);
Block& block = body.front();
if (argIdx >= block.getNumArguments())
return std::nullopt;
return block.getArgument(argIdx);
}
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
if (body.empty())
return std::nullopt;
return body.insertArgument(argIdx, type, loc);
}
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); }
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(
newCompute.getOperation(), static_cast<int32_t>(newCompute.getWeights().size()), static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
@@ -18,42 +88,102 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
}
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + idx);
}
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(
newBatch.getOperation(), static_cast<int32_t>(newBatch.getWeights().size()), static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) {
setNameFn(getOutputArgument(index), "out");
setNameFn(*outputArg, "out");
continue;
}
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
}
}
+3
View File
@@ -5,12 +5,15 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include <map>
#include <optional>
#include <string>
#include <tuple>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
+48 -20
View File
@@ -218,17 +218,26 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
@@ -309,29 +318,48 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
auto laneArg = getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(getLaneArgument());
printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount();
printer << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
SmallVector<BlockArgument> outputArgs;
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index)
outputArgs.push_back(getOutputArgument(index));
printBlockArgumentList(printer, outputArgs);
}
+35 -19
View File
@@ -107,8 +107,11 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
return false;
unsigned argNumber = blockArg.getArgNumber();
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
auto firstOutputArg = batchOp.getOutputArgument(0);
if (!firstOutputArg)
return false;
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
}
static bool isConstantIndexLike(Value value) {
@@ -293,10 +296,12 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
}
BlockArgument laneArg = batchOp.getLaneArgument();
auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return batchOp.emitError("compute_batch body must have a lane block argument");
for (auto& bodyOp : block) {
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
return failure();
}
return success();
@@ -457,12 +462,16 @@ LogicalResult SpatCompute::verify() {
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
if (getInputArgument(inputIndex).getType() != input.getType())
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -497,7 +506,7 @@ LogicalResult SpatCompute::verify() {
}
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (getInputArgument(inputIndex).use_empty())
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return emitError("ComputeOp block argument is not used");
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure();
@@ -574,23 +583,28 @@ LogicalResult SpatComputeBatch::verify() {
}
Block& block = getBody().front();
if (block.getNumArguments() == 0)
return emitError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
auto laneArg = getLaneArgument();
if (!laneArg || !laneArg->getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType())
auto blockArg = getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly");
}
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
BlockArgument blockArg = getOutputArgument(resultIndex);
if (blockArg.getType() != resultType)
auto blockArg = getOutputArgument(resultIndex);
if (!blockArg || blockArg->getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
}
@@ -608,13 +622,15 @@ LogicalResult SpatInParallelOp::verify() {
if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
BlockArgument laneArg = batchOp.getLaneArgument();
auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
@@ -432,15 +432,6 @@ LogicalResult collectHostOutputs(MaterializerState& state) {
return success();
}
void setOperandSegmentSizes(Operation* op, int weightCount, int inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
auto batch = cast<SpatComputeBatch>(op);
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
void createEmptyMaterializedOps(MaterializerState& state) {
Location loc = state.func.getLoc();
Block& funcBlock = state.func.getBody().front();
@@ -529,19 +520,17 @@ BlockArgument appendWeight(MaterializerState& state, MaterializedClass& material
materializedClass.weights.push_back(weight);
if (auto compute = dyn_cast<SpatCompute>(materializedClass.op)) {
compute.getWeightsMutable().append(ValueRange(weight));
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
BlockArgument arg = materializedClass.body->insertArgument(weightIndex, weight.getType(), weight.getLoc());
materializedClass.weightArgs[weight] = arg;
return arg;
auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc());
assert(arg && "expected compute body while inserting a weight");
materializedClass.weightArgs[weight] = std::get<1>(*arg);
return std::get<1>(*arg);
}
auto batch = cast<SpatComputeBatch>(materializedClass.op);
batch.getWeightsMutable().append(ValueRange(weight));
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
BlockArgument arg = materializedClass.body->insertArgument(1 + weightIndex, weight.getType(), weight.getLoc());
materializedClass.weightArgs[weight] = arg;
return arg;
auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc());
assert(arg && "expected compute_batch body while inserting a weight argument");
materializedClass.weightArgs[weight] = std::get<1>(*arg);
return std::get<1>(*arg);
}
BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) {
@@ -551,17 +540,16 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
materializedClass.inputs.push_back(input);
if (auto compute = dyn_cast<SpatCompute>(materializedClass.op)) {
compute.getInputsMutable().append(ValueRange(input));
BlockArgument arg = materializedClass.body->addArgument(input.getType(), input.getLoc());
materializedClass.inputArgs[input] = arg;
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
assert(arg && "expected compute body while inserting an input");
materializedClass.inputArgs[input] = std::get<1>(*arg);
return std::get<1>(*arg);
}
else {
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
BlockArgument arg = materializedClass.body->insertArgument(
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
materializedClass.inputArgs[input] = arg;
return arg;
if (auto compute = dyn_cast<SpatComputeBatch>(materializedClass.op)) {
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
assert(arg && "expected compute_batch body while inserting an input argument");
materializedClass.inputArgs[input] = std::get<1>(*arg);
return std::get<1>(*arg);
}
llvm_unreachable("Cannot reach here");
}
@@ -608,6 +596,8 @@ Value createOriginalLaneValue(MaterializerState& state,
return createIndexConstant(state, materializedClass.op, peers.front().laneStart);
auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected materialized compute_batch lane argument");
bool identity = true;
for (auto [lane, peer] : llvm::enumerate(peers)) {
if (peer.laneCount != 1 || peer.laneStart != lane) {
@@ -616,7 +606,7 @@ Value createOriginalLaneValue(MaterializerState& state,
}
}
if (identity)
return batch.getLaneArgument();
return *laneArg;
bool affineWithBase = true;
int64_t base = static_cast<int64_t>(peers.front().laneStart);
@@ -628,9 +618,9 @@ Value createOriginalLaneValue(MaterializerState& state,
}
if (affineWithBase) {
if (base == 0)
return batch.getLaneArgument();
return *laneArg;
Value baseValue = createIndexConstant(state, materializedClass.op, base);
return arith::AddIOp::create(state.rewriter, loc, batch.getLaneArgument(), baseValue).getResult();
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
}
SmallVector<APInt, 8> laneValues;
@@ -641,7 +631,7 @@ Value createOriginalLaneValue(MaterializerState& state,
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType());
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {batch.getLaneArgument()}).getResult();
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps) {
@@ -838,7 +828,10 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
offsets.reserve(payloadType.getRank());
sizes.reserve(payloadType.getRank());
strides.reserve(payloadType.getRank());
offsets.push_back(batch.getLaneArgument());
auto laneArg = batch.getLaneArgument();
if (!laneArg)
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
offsets.push_back(*laneArg);
sizes.push_back(state.rewriter.getIndexAttr(1));
strides.push_back(state.rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
@@ -847,8 +840,11 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
strides.push_back(state.rewriter.getIndexAttr(1));
}
tensor::ParallelInsertSliceOp::create(
state.rewriter, payload.getLoc(), payload, batch.getOutputArgument(resultIndex), offsets, sizes, strides);
auto outputArg = batch.getOutputArgument(resultIndex);
if (!outputArg)
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
return success();
}
@@ -1136,14 +1132,20 @@ void mapWeights(MaterializerState& state,
IRMapping& mapper) {
Operation* op = instance.op;
if (auto compute = dyn_cast<SpatCompute>(op)) {
for (auto [index, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(index), appendWeight(state, targetClass, weight));
for (auto [index, weight] : llvm::enumerate(compute.getWeights())) {
auto weightArg = compute.getWeightArgument(index);
assert(weightArg && "expected compute weight block argument");
mapper.map(*weightArg, appendWeight(state, targetClass, weight));
}
return;
}
auto batch = cast<SpatComputeBatch>(op);
for (auto [index, weight] : llvm::enumerate(batch.getWeights()))
mapper.map(batch.getWeightArgument(index), appendWeight(state, targetClass, weight));
for (auto [index, weight] : llvm::enumerate(batch.getWeights())) {
auto weightArg = batch.getWeightArgument(index);
assert(weightArg && "expected compute_batch weight block argument");
mapper.map(*weightArg, appendWeight(state, targetClass, weight));
}
}
LogicalResult mapInputs(MaterializerState& state,
@@ -1156,7 +1158,10 @@ LogicalResult mapInputs(MaterializerState& state,
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
if (failed(mapped))
return compute.emitOpError("failed to resolve materialized compute input");
mapper.map(compute.getInputArgument(index), *mapped);
auto inputArg = compute.getInputArgument(index);
if (!inputArg)
return compute.emitOpError("expected compute input block argument while materializing inputs");
mapper.map(*inputArg, *mapped);
}
return success();
}
@@ -1166,7 +1171,10 @@ LogicalResult mapInputs(MaterializerState& state,
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
if (failed(mapped))
return batch.emitOpError("failed to resolve materialized compute_batch input");
mapper.map(batch.getInputArgument(index), *mapped);
auto inputArg = batch.getInputArgument(index);
if (!inputArg)
return batch.emitOpError("expected compute_batch input block argument while materializing inputs");
mapper.map(*inputArg, *mapped);
}
return success();
}
@@ -1186,8 +1194,10 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
if (!outputArg || outputArg.getOwner() != &batch.getBody().front())
continue;
unsigned firstOutputArg = batch.getOutputArgument(0).getArgNumber();
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
auto firstOutputArg = batch.getOutputArgument(0);
if (!firstOutputArg)
return outputs;
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= outputs.size())
continue;
outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource());
@@ -1217,7 +1227,12 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra
return failure();
}
}
mapper.map(batch.getLaneArgument(), createOriginalLaneValue(state, targetClass, peers, loc));
auto laneArg = batch.getLaneArgument();
if (!laneArg) {
sourceOp->emitError("expected source compute_batch lane block argument");
return failure();
}
mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc));
}
mapWeights(state, targetClass, instance, mapper);
@@ -223,18 +223,32 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
newBody->addArgument(input.getType(), loc);
IRMapping mapper;
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs()))
mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex));
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]));
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
assert(oldWeightArg && newWeightArg && "expected compute weight block arguments");
mapper.map(*oldWeightArg, *newWeightArg);
}
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) {
auto oldInputArg = compute.getInputArgument(inputIndex);
auto newInputArg = newCompute.getInputArgument(inputIndex);
assert(oldInputArg && newInputArg && "expected compute input block arguments");
mapper.map(*oldInputArg, *newInputArg);
}
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) {
auto oldWeightArg = child.getWeightArgument(oldIndex);
auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]);
assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments");
mapper.map(*oldWeightArg, *newWeightArg);
}
rewriter.setInsertionPointToEnd(newBody);
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
for (Operation& op : compute.getBody().front().without_terminator())
rewriter.clone(op, mapper);
mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
auto childInputArg = child.getInputArgument(childInputIndex);
assert(childInputArg && "expected child compute input block argument");
mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
rewriter.setInsertionPointToEnd(newBody);
for (auto& op : child.getBody().front())