This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
+121 -12
View File
@@ -26,7 +26,7 @@ def SpatTensor :
// Execution
//===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute",
class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights";
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatComputeBatch : SpatOp<"compute_batch",
// `scheduled_compute`: one of the ops instantiated from SpatComputeLikeBase.
// The op itself adds nothing but the C++ helper declarations below: lookups for
// the block arguments bound to weights and inputs, helpers that insert a new
// weight, input or result, and a query for the crossbar weights.
// insertOutput returns a SpatScheduledCompute alongside the new result, so
// callers get an op handle back from that call.
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
// `scheduled_compute_batch`: one of the ops instantiated from
// SpatComputeBatchLikeBase. Compared with the non-batch compute ops it also
// declares getLaneArgument and getOutputArgument, and its insertOutput returns
// a block argument in addition to the new result and a
// SpatScheduledComputeBatch handle.
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
def SpatInParallelOp : SpatOp<"in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch";
let summary = "Parallel combining terminator for resultful Spatial compute batches";
let regions = (region SizedRegion<1>:$region);
@@ -159,6 +192,82 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Planning
//===----------------------------------------------------------------------===//
// `conv2d_plan`: planning-stage Conv2D op with no traits attached.
// Operands: an input tensor, a weight tensor and an optional bias tensor.
// Attributes: `pads`, `strides` and `dilations` as dense i64 arrays, `group`
// as a 64-bit integer, and `logicalLayout` as a string.
// Produces a single tensor result. A custom C++ verifier is requested
// (hasVerifier = 1); its checks are not visible in this file.
// NOTE(review): the accepted values of `logicalLayout` are not defined here.
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
let arguments = (ins
SpatTensor:$input,
SpatTensor:$weight,
Optional<SpatTensor>:$bias,
DenseI64ArrayAttr:$pads,
DenseI64ArrayAttr:$strides,
DenseI64ArrayAttr:$dilations,
I64Attr:$group,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
// `relu_plan`: planning-stage ReLU op with no traits attached.
// Takes one tensor plus a string `logicalLayout` attribute and produces one
// tensor. A custom C++ verifier is requested (hasVerifier = 1); its checks are
// not visible in this file.
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
let summary = "Layout-aware ReLU planning op";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
// `reconciliator`: records a logical-to-physical layout selection for a tensor
// (see the summary). No traits are attached.
// Takes one tensor and produces one tensor. Attributes: `logicalLayout` and
// `physicalLayout` strings, `fragmentOffsets` and `fragmentSizes` dense i64
// arrays, and an `indexMap` string.
// A custom C++ verifier is requested (hasVerifier = 1).
// NOTE(review): the string encodings of the layouts and of `indexMap`, and how
// the two fragment arrays relate, are not visible here - confirm against the
// verifier.
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
let summary = "Passive logical-to-physical layout selection record";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout,
StrAttr:$physicalLayout,
DenseI64ArrayAttr:$fragmentOffsets,
DenseI64ArrayAttr:$fragmentSizes,
StrAttr:$indexMap
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
// `materialize_layout`: explicit layout conversion / materialization barrier
// (see the summary). No traits are attached.
// Takes one tensor and produces one tensor. Attributes are all strings: the
// `logicalLayout` plus the source and target physical layouts.
// A custom C++ verifier is requested (hasVerifier = 1); its checks are not
// visible in this file.
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
let summary = "Explicit layout conversion or materialization barrier";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout,
StrAttr:$sourcePhysicalLayout,
StrAttr:$targetPhysicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
+235 -118
View File
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
}
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
return weights;
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), getWeights().size() + idx);
template <typename ComputeOpTy>
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
return getBlockArgument(compute.getBody(), idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
template <typename ComputeOpTy>
std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
}
template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
auto index = std::distance(compute.getWeights().begin(), existing);
return {{*existing, *getComputeWeightArgument(compute, index)}};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
unsigned weightCount = compute.getWeights().size();
unsigned inputCount = compute.getInputs().size();
compute.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
template <typename ComputeBatchOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
auto index = std::distance(batch.getWeights().begin(), existing);
return {{*existing, *batch.getWeightArgument(index)}};
}
unsigned weightCount = batch.getWeights().size();
unsigned inputCount = batch.getInputs().size();
batch.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
}
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
/// Adds `input` as the `idx`-th input operand of a non-batch compute op.
///
/// The operand goes in after the existing weights, the operand segment sizes
/// are updated to record one more input, and a matching block argument of the
/// input's type is inserted at the same position in the body.
///
/// Returns the inserted operand together with its block argument, or
/// std::nullopt when the block argument could not be inserted. The operand and
/// segment sizes have already been changed by then.
template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
  const unsigned numWeights = compute.getWeights().size();
  const unsigned numInputs = compute.getInputs().size();
  // Inputs live after the weights in both the operand list and the block arguments.
  const unsigned position = numWeights + idx;

  compute.getOperation()->insertOperands(position, ValueRange {input});
  setComputeOperandSegmentSizes(
      compute.getOperation(), static_cast<int32_t>(numWeights), static_cast<int32_t>(numInputs + 1));

  auto newArg = insertBlockArgument(compute.getBody(), position, input.getType(), loc);
  if (newArg)
    return std::make_tuple(compute.getOperation()->getOperand(position), *newArg);
  return std::nullopt;
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
template <typename ComputeOpTy>
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
for (unsigned index = 0; index < compute.getWeights().size(); ++index)
if (auto weightArg = compute.getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
for (unsigned index = 0; index < compute.getInputs().size(); ++index)
if (auto inputArg = compute.getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
}
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
/// Rebuilds `compute` with one extra result of `type` inserted at result index `idx`.
///
/// A new op of the same kind is created directly before `compute` with the
/// widened result list and the same weights and inputs. The attributes are
/// copied, the operand segment sizes are re-applied, and the body region is
/// moved into the new op. Every use of an old result is redirected to the
/// corresponding new result (results at or after `idx` shift up by one), and
/// the old op is erased.
///
/// All IR mutations go through `rewriter`, so any attached listener sees the
/// use replacements as well as the creation and erasure.
///
/// \param compute   op to extend; it is erased on success.
/// \param rewriter  rewriter used for creation, use replacement and erasure.
/// \param idx       result position of the new result; must be <= the current result count.
/// \param type      type of the new result.
/// \param loc       not used by this helper; the new op reuses `compute`'s location.
/// \returns the new result and the replacement op, or failure() when `idx` is out of range.
template <typename ComputeOpTy>
FailureOr<std::tuple<OpResult, ComputeOpTy>>
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
  if (idx > compute.getNumResults())
    return failure();

  rewriter.setInsertionPoint(compute.getOperation());
  SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
  resultTypes.insert(resultTypes.begin() + idx, type);
  auto newCompute =
      ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
  newCompute->setAttrs(compute->getAttrs());
  setComputeOperandSegmentSizes(newCompute.getOperation(),
                                static_cast<int32_t>(newCompute.getWeights().size()),
                                static_cast<int32_t>(newCompute.getInputs().size()));
  rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());

  // Redirect uses through the rewriter so pattern drivers are notified of the modified users.
  for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx) {
    const unsigned newResultIdx = oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1;
    rewriter.replaceAllUsesWith(compute.getResult(oldResultIdx), newCompute.getResult(newResultIdx));
  }
  rewriter.eraseOp(compute.getOperation());
  return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
/// Rebuilds `batch` with one extra result of `type` inserted at result index `idx`.
///
/// A new op of the same kind is created directly before `batch` with the
/// widened result list and the same lane count, weights and inputs. The
/// attributes are copied, the operand segment sizes are re-applied, and the
/// body region is moved into the new op. A block argument for the new output
/// is inserted after the lane, weight and input arguments, at offset `idx`
/// among the output arguments. Every use of an old result is redirected to the
/// corresponding new result (results at or after `idx` shift up by one), and
/// the old op is erased.
///
/// Use replacement and erasure go through `rewriter`, so any attached listener
/// sees them.
///
/// \param batch     op to extend; it is erased on success.
/// \param rewriter  rewriter used for creation, use replacement and erasure.
/// \param idx       result position of the new result; must be <= the current result count.
/// \param type      type of the new result and of its block argument.
/// \param loc       location given to the new block argument.
/// \returns the new result, its block argument and the replacement op. Returns
///          failure() when `idx` is out of range, or when the body has no block
///          once moved (the freshly created op is erased again in that case).
template <typename ComputeBatchOpTy>
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
  if (idx > batch.getNumResults())
    return failure();

  rewriter.setInsertionPoint(batch.getOperation());
  SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
  resultTypes.insert(resultTypes.begin() + idx, type);
  auto newBatch = ComputeBatchOpTy::create(
      rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
  newBatch->setAttrs(batch->getAttrs());
  setComputeOperandSegmentSizes(newBatch.getOperation(),
                                static_cast<int32_t>(newBatch.getWeights().size()),
                                static_cast<int32_t>(newBatch.getInputs().size()));
  rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
  if (newBatch.getBody().empty()) {
    // Without an entry block there is nowhere to put the output argument.
    rewriter.eraseOp(newBatch);
    return failure();
  }

  // Block argument layout: [lane, weights..., inputs..., outputs...].
  auto blockArg = newBatch.getBody().front().insertArgument(
      1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);

  // Redirect uses through the rewriter so pattern drivers are notified of the modified users.
  for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx) {
    const unsigned newResultIdx = oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1;
    rewriter.replaceAllUsesWith(batch.getResult(oldResultIdx), newBatch.getResult(newResultIdx));
  }
  rewriter.eraseOp(batch.getOperation());
  return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
} // namespace
/// True for the batch graph-stage compute op.
bool isGraphBatchComputeLike(Operation* op) {
  return isa<SpatGraphComputeBatch>(op);
}

/// True for the batch scheduled-stage compute op.
bool isScheduledBatchComputeLike(Operation* op) {
  return isa<SpatScheduledComputeBatch>(op);
}

/// True for either graph-stage compute op (single or batch).
bool isGraphComputeLike(Operation* op) {
  return isa<SpatGraphCompute>(op) || isa<SpatGraphComputeBatch>(op);
}

/// True for either scheduled-stage compute op (single or batch).
bool isScheduledComputeLike(Operation* op) {
  return isa<SpatScheduledCompute>(op) || isa<SpatScheduledComputeBatch>(op);
}

/// True for any of the four Spatial compute ops.
bool isAnySpatialComputeLike(Operation* op) {
  return isGraphComputeLike(op) || isScheduledComputeLike(op);
}

/// True for either batch compute op (graph or scheduled).
bool isAnySpatialComputeBatchLike(Operation* op) {
  return isGraphBatchComputeLike(op) || isScheduledBatchComputeLike(op);
}
// SpatGraphCompute member functions. Each one forwards to a file-local helper,
// passing *this as the op.

/// Block argument bound to the idx-th weight; forwards to getComputeWeightArgument.
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
/// Block argument bound to the idx-th input; forwards to getComputeInputArgument.
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
/// Adds a weight operand and its block argument; forwards to insertComputeWeight.
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
/// Adds an input operand and its block argument; forwards to insertComputeInput.
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
/// Returns whatever collectCrossbarWeights gathers from this op's body region.
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
/// Rebuilds this op with an extra result at idx; forwards to insertComputeOutput.
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
/// OpAsmOpInterface hook; forwards to setComputeAsmBlockArgumentNames.
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
// SpatScheduledCompute member functions. Each one forwards to the same
// file-local helpers used by SpatGraphCompute, passing *this as the op.

/// Block argument bound to the idx-th weight; forwards to getComputeWeightArgument.
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
return getComputeWeightArgument(*this, idx);
}
/// Block argument bound to the idx-th input; forwards to getComputeInputArgument.
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
/// Adds a weight operand and its block argument; forwards to insertComputeWeight.
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
/// Adds an input operand and its block argument; forwards to insertComputeInput.
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
/// Returns whatever collectCrossbarWeights gathers from this op's body region.
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
/// Rebuilds this op with an extra result at idx; forwards to insertComputeOutput.
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
/// OpAsmOpInterface hook; forwards to setComputeAsmBlockArgumentNames.
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeBatchWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
std::optional<std::tuple<Value, BlockArgument>>
SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) {
setNameFn(*outputArg, "out");
continue;
}
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
}
}
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
// Block-argument lookups for SpatScheduledComputeBatch. The indices used below
// give the body argument order: lane (0), then weights, then inputs, then outputs.

/// Body block argument 0, the lane argument.
std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
/// Block argument of the idx-th weight; offset by one to skip the lane argument.
std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx);
}
/// Block argument of the idx-th input; located after the lane and all weight arguments.
std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
}
/// Block argument of the idx-th output; located after the lane, weight and input arguments.
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
/// Adds a weight operand and its block argument; forwards to insertComputeBatchWeight.
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeBatchWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
/// Returns whatever collectCrossbarWeights gathers from this op's body region.
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
/// Rebuilds this op with an extra result at idx; forwards to insertComputeBatchOutput.
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
}
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
builder.createBlock(bodyRegion);
}
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
/// Returns result `idx` of the op that encloses this terminator.
/// In assert-enabled builds the enclosing op is checked to be one of the
/// Spatial compute batch ops.
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
  Operation* enclosingOp = (*this)->getParentOp();
  assert(isAnySpatialComputeBatchLike(enclosingOp) && "expected Spatial compute batch parent");
  return enclosingOp->getResult(idx);
}
/// Returns the operations of the first block of this op's region.
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
  Block& firstBlock = getRegion().front();
  return firstBlock.getOperations();
}
+16
View File
@@ -26,3 +26,19 @@
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
namespace onnx_mlir {
namespace spatial {
// Classification helpers over the four Spatial compute ops
// (SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute,
// SpatScheduledComputeBatch).

// True for SpatGraphCompute or SpatGraphComputeBatch.
bool isGraphComputeLike(mlir::Operation* op);
// True for SpatGraphComputeBatch only.
bool isGraphBatchComputeLike(mlir::Operation* op);
// True for SpatScheduledCompute or SpatScheduledComputeBatch.
bool isScheduledComputeLike(mlir::Operation* op);
// True for SpatScheduledComputeBatch only.
bool isScheduledBatchComputeLike(mlir::Operation* op);
// True for any of the four compute ops.
bool isAnySpatialComputeLike(mlir::Operation* op);
// True for either batch op (graph or scheduled).
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
// Aliases that map the names SpatCompute / SpatComputeBatch onto the graph
// variants. NOTE(review): presumably kept so existing code using those names
// keeps compiling - confirm.
using SpatCompute = SpatGraphCompute;
using SpatComputeBatch = SpatGraphComputeBatch;
} // namespace spatial
} // namespace onnx_mlir
+260 -251
View File
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return success();
}
template <typename ComputeOpTy>
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(op.getWeights().size());
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
auto weightArg = op.getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(op.getInputs().size());
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
auto inputArg = op.getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, op.getResultTypes());
printer << " ";
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
/// Parses the custom assembly shared by the non-batched compute-like ops.
///
/// Elements are consumed in this order:
///   1. weight bindings (square-delimited bound value list),
///   2. input bindings (paren-delimited bound value list),
///   3. optional `coreId` / `core_id` followed by an integer,
///   4. optional `crossbarWeights` / `crossbar_weights` followed by an integer,
///   5. optional attribute dictionary, `:`, weight types, input types, `->`,
///      result types (possibly empty),
///   6. the body region, whose entry arguments are the weight bindings
///      followed by the input bindings.
///
/// The crossbar weight count is accepted for syntax compatibility only; it is
/// parsed and then dropped without being stored on the operation.
///
/// `ComputeOpTy` is not referenced by the body; it only selects the
/// instantiation.
template <typename ComputeOpTy>
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
  // Operand bindings: each list yields the region arguments and the outer
  // operands they are bound to.
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();

  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  // Optional positional core id.
  int32_t coreId = 0;
  const bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
  if (hasCoreId && parser.parseInteger(coreId))
    return failure();

  // Optional crossbar weight count: validated as an integer, then discarded.
  if (parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights")) {
    int32_t discardedCrossbarWeightCount = 0;
    if (parser.parseInteger(discardedCrossbarWeightCount))
      return failure();
  }

  // Attribute dictionary and the functional-style type signature.
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parseCompressedRepeatedList(parser, ListDelimiter::Square, weightTypes, parseOneType))
    return failure();
  if (parseCompressedRepeatedList(parser, ListDelimiter::Paren, inputTypes, parseOneType))
    return failure();
  if (parser.parseArrow())
    return failure();
  if (parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  // Arity checks between operands, their bindings and their types.
  auto countError = [&](StringRef message) {
    return parser.emitError(parser.getCurrentLocation(), message);
  };
  if (weights.size() != weightTypes.size())
    return countError("number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return countError("number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return countError("number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return countError("number of argument bindings and input operands must match");

  // The core id may come from the positional clause or the attr-dict, not both.
  if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return countError("coreId cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  const auto numWeights = static_cast<int32_t>(weights.size());
  const auto numInputs = static_cast<int32_t>(inputs.size());
  result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr({numWeights, numInputs}));
  if (hasCoreId)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));

  // Operands are resolved weights first, then inputs, matching the segments.
  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  // Region entry arguments: typed weight bindings followed by input bindings.
  Region* body = result.addRegion();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  SmallVector<OpAsmParser::Argument> regionArgs;
  llvm::append_range(regionArgs, weightArgs);
  llvm::append_range(regionArgs, inputArgs);
  return parser.parseRegion(*body, regionArgs);
}
/// Prints the custom assembly shared by the batched compute-like ops.
///
/// If the lane argument, or any weight / input / output block argument, cannot
/// be obtained from the op, the op is printed in generic form instead and
/// nothing else is emitted.
///
/// Otherwise the output is, in order: the lane argument with its `0 to
/// laneCount` range, the weight bindings, the input bindings, an optional
/// `shared_outs` list (only when the op has results), `crossbarWeights <n>`,
/// an optional `coreIds` list, the remaining attributes, and finally the
/// type signature and the body region (without entry block arguments).
template <typename ComputeBatchOpTy>
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
  auto printGenericFallback = [&] { printer.printGenericOp(op.getOperation(), /*printOpName=*/false); };

  auto laneArg = op.getLaneArgument();

  // Collect the block arguments bound to each weight operand.
  SmallVector<Value> weightArgs;
  weightArgs.reserve(op.getWeights().size());
  for (unsigned weightIdx = 0; weightIdx < op.getWeights().size(); ++weightIdx) {
    auto boundArg = op.getWeightArgument(weightIdx);
    if (!boundArg) {
      printGenericFallback();
      return;
    }
    weightArgs.push_back(*boundArg);
  }

  // Collect the block arguments bound to each input operand.
  SmallVector<Value> inputArgs;
  inputArgs.reserve(op.getInputs().size());
  for (unsigned inputIdx = 0; inputIdx < op.getInputs().size(); ++inputIdx) {
    auto boundArg = op.getInputArgument(inputIdx);
    if (!boundArg) {
      printGenericFallback();
      return;
    }
    inputArgs.push_back(*boundArg);
  }

  if (!laneArg) {
    printGenericFallback();
    return;
  }

  // One output block argument per result; empty when the op has no results.
  const unsigned numResults = op.getNumResults();
  SmallVector<BlockArgument> outputArgs;
  outputArgs.reserve(numResults);
  for (unsigned resultIdx = 0; resultIdx < numResults; ++resultIdx) {
    auto boundArg = op.getOutputArgument(resultIdx);
    if (!boundArg) {
      printGenericFallback();
      return;
    }
    outputArgs.push_back(*boundArg);
  }

  // Lane range header.
  printer << " ";
  printer.printOperand(*laneArg);
  printer << " = 0 to " << op.getLaneCount() << " ";

  // Operand bindings.
  printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
  printer << " ";
  printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
  if (numResults != 0) {
    printer << " shared_outs";
    printBlockArgumentList(printer, outputArgs);
  }

  // Derived crossbar usage count, then the optional core id list.
  const auto crossbarUsage = getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
  printer << " crossbarWeights " << crossbarUsage;
  if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    printer << " coreIds ";
    printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
  }

  // Attributes already rendered by the custom syntax are elided here.
  printer.printOptionalAttrDict(
      op->getAttrs(),
      {op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});

  // Type signature and body.
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, op.getResultTypes());
  printer << " ";
  printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
/// Parses the custom assembly shared by the batched compute-like ops.
///
/// Elements are consumed in this order:
///   1. `<lane-arg> = <lower-bound> to <lane-count>`; the lower bound must be 0,
///   2. weight bindings (square-delimited) and input bindings (paren-delimited),
///   3. optional `shared_outs` followed by the output block argument list,
///   4. optional `coreIds` / `core_ids` list and optional `crossbarWeights` /
///      `crossbar_weights` integer, accepted in either order,
///   5. optional attribute dictionary, `:`, weight types, input types, `->`,
///      result types (possibly empty),
///   6. the body region.
///
/// The two optional clauses in step 4 are accepted in either order because the
/// matching printer emits `crossbarWeights` before `coreIds`, while earlier
/// textual IR may use the opposite order; a fixed order would make the printed
/// form of an op carrying both clauses impossible to parse back.
///
/// The crossbar weight count is accepted for syntax compatibility only; it is
/// parsed and then dropped without being stored on the operation.
///
/// `ComputeBatchOpTy` is not referenced by the body; it only selects the
/// instantiation.
template <typename ComputeBatchOpTy>
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  OpAsmParser::Argument laneArg;
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::Argument> outputArgs;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  int32_t crossbarWeightCount = 0;
  SmallVector<int32_t> coreIds;

  // Lane range header.
  if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
      || parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");

  // Operand bindings and optional shared output bindings.
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();
  if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();
  if (succeeded(parser.parseOptionalKeyword("shared_outs")))
    if (parseBlockArgumentList(parser, outputArgs))
      return failure();

  // Optional `coreIds` clause; a no-op once the clause has been consumed so it
  // can be attempted both before and after `crossbarWeights`.
  bool hasCoreIds = false;
  auto parseOptionalCoreIds = [&]() -> ParseResult {
    if (hasCoreIds || !parseOptionalKeywordAlias(parser, "coreIds", "core_ids"))
      return success();
    hasCoreIds = true;
    return parseCompressedIntegerList(parser, coreIds);
  };

  if (parseOptionalCoreIds())
    return failure();
  bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
  if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
    return failure();
  // Second attempt covers the `crossbarWeights ... coreIds ...` ordering.
  if (parseOptionalCoreIds())
    return failure();
  (void) crossbarWeightCount;

  // Attribute dictionary and the functional-style type signature.
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  // Arity checks between operands, their bindings and their types.
  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
  if (outputArgs.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(),
                            "number of shared output bindings and result types must match");

  // The core ids may come from the positional clause or the attr-dict, not both.
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "coreIds cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  // Operands are resolved weights first, then inputs, matching the segments.
  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  // The helper assigns types to the parsed bindings and fills `regionArgs`.
  Region* body = result.addRegion();
  applyBatchRegionArgumentTypes(
      inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
  return parser.parseRegion(*body, regionArgs);
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
void SpatCompute::print(OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
/// Custom printer for SpatGraphCompute; forwards to the shared compute-like printer.
void SpatGraphCompute::print(OpAsmPrinter& printer) {
  printComputeLikeOp(*this, printer);
}
/// Custom parser for SpatGraphCompute; forwards to the shared compute-like parser.
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
  ParseResult status = parseComputeLikeOp<SpatGraphCompute>(parser, result);
  return status;
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
/// Custom printer for SpatScheduledCompute; forwards to the shared compute-like printer.
void SpatScheduledCompute::print(OpAsmPrinter& printer) {
  printComputeLikeOp(*this, printer);
}
/// Custom parser for SpatScheduledCompute; forwards to the shared compute-like parser.
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
  ParseResult status = parseComputeLikeOp<SpatScheduledCompute>(parser, result);
  return status;
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
auto laneArg = getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
/// Custom printer for SpatGraphComputeBatch; forwards to the shared batch printer.
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) {
  printComputeBatchLikeOp(*this, printer);
}
/// Custom parser for SpatGraphComputeBatch; forwards to the shared batch parser.
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  ParseResult status = parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
  return status;
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
/// Custom printer for SpatScheduledComputeBatch; forwards to the shared batch printer.
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) {
  printComputeBatchLikeOp(*this, printer);
}
/// Custom parser for SpatScheduledComputeBatch; forwards to the shared batch parser.
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  ParseResult status = parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
  return status;
}
void SpatInParallelOp::print(OpAsmPrinter& printer) {
@@ -10,8 +10,9 @@ using namespace mlir;
namespace onnx_mlir {
namespace spatial {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
template <typename ComputeOpTy>
LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = compute.getBody().front();
if (!llvm::hasSingleElement(block))
return failure();
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber()));
results.push_back(compute.getOperand(blockArg.getArgNumber()));
continue;
}
}
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
return success();
}
/// Folder for SpatGraphCompute. The fold adaptor is not consulted; the shared
/// compute-like fold works on the op itself and appends to `results`.
LogicalResult SpatGraphCompute::fold(FoldAdaptor /*adaptor*/,
                                     ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
  LogicalResult folded = foldComputeLike(*this, results);
  return folded;
}
/// Folder for SpatScheduledCompute. The fold adaptor is not consulted; the
/// shared compute-like fold works on the op itself and appends to `results`.
LogicalResult SpatScheduledCompute::fold(FoldAdaptor /*adaptor*/,
                                         ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
  LogicalResult folded = foldComputeLike(*this, results);
  return folded;
}
} // namespace spatial
} // namespace onnx_mlir
+237 -76
View File
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
return shapedType.getShape();
}
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
template <typename ComputeBatchOpTy>
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
auto blockArg = dyn_cast<BlockArgument>(value);
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
return success();
}
/// Returns true when `value` is built only from constant index values.
///
/// Accepted forms (checked recursively):
///   - anything `matchConstantIndexValue` recognises,
///   - an `affine.apply` whose map passes `isSingleResultSymbolFreeAffineMap`
///     and whose map operands are all static,
///   - an `arith.addi` or `arith.muli` whose two operands are both static.
/// Every other value is rejected.
static bool isStaticIndexExpr(Value value) {
  if (matchConstantIndexValue(value))
    return true;

  if (auto applyOp = value.getDefiningOp<affine::AffineApplyOp>()) {
    // An affine.apply is judged solely on its map and operands; it never
    // falls through to the arithmetic cases below.
    return isSingleResultSymbolFreeAffineMap(applyOp.getAffineMap())
        && llvm::all_of(applyOp.getMapOperands(), isStaticIndexExpr);
  }

  if (auto sumOp = value.getDefiningOp<arith::AddIOp>()) {
    if (!isStaticIndexExpr(sumOp.getLhs()))
      return false;
    return isStaticIndexExpr(sumOp.getRhs());
  }

  if (auto productOp = value.getDefiningOp<arith::MulIOp>()) {
    if (!isStaticIndexExpr(productOp.getLhs()))
      return false;
    return isStaticIndexExpr(productOp.getRhs());
  }

  return false;
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || matchConstantIndexValue(value))
if (value == laneArg || isStaticIndexExpr(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
}
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
if (addOp)
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
auto mulOp = value.getDefiningOp<arith::MulIOp>();
if (!mulOp)
return false;
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|| (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
}
static LogicalResult
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may not capture external values";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
<< "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
<< operand.getOperandNumber() << " type=" << value.getType()
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
: "scalar"));
if (Operation* definingOp = value.getDefiningOp())
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
else if (auto blockArg = dyn_cast<BlockArgument>(value))
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
<< blockArg.getOwner()->getParentOp()->getName() << "'";
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
template <typename ComputeBatchOpTy>
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -344,144 +380,266 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute");
/// Returns true when `layout` names a recognised logical layout.
/// The only recognised name is "nchw".
static bool isKnownLogicalLayout(StringRef layout) {
  const bool isNchw = layout == "nchw";
  return isNchw;
}
/// Returns true when `layout` names a recognised physical layout.
/// The recognised names are "dense_nchw" and "nchw_row_strip".
static bool isKnownPhysicalLayout(StringRef layout) {
  if (layout == "dense_nchw")
    return true;
  return layout == "nchw_row_strip";
}
/// Shared type check for plan-style ops with one tensor input and one output.
///
/// Emits an op error on `op` (prefixed with `kind`) and fails when either
/// value is not a ranked tensor, or when the element types differ.
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
  auto rankedInput = dyn_cast<RankedTensorType>(input.getType());
  auto rankedOutput = dyn_cast<RankedTensorType>(output.getType());

  const bool bothRanked = rankedInput && rankedOutput;
  if (!bothRanked)
    return op->emitOpError() << kind << " requires ranked tensor input and output types";

  const bool sameElementType = rankedInput.getElementType() == rankedOutput.getElementType();
  if (!sameElementType)
    return op->emitOpError() << kind << " requires matching input/output element types";

  return success();
}
/// Verifies the structural invariants of a conv2d plan op.
///
/// Checks, in order: input/weight/output are ranked tensors of rank 4; the
/// logical layout is recognised; there are exactly four pads, two strides and
/// two dilations; `group` is at least 1; input, weight and output share one
/// element type; and, when a bias is present, it is a ranked tensor whose
/// element type equals the output's.
LogicalResult SpatConv2DPlanOp::verify() {
  auto inputTy = dyn_cast<RankedTensorType>(getInput().getType());
  auto weightTy = dyn_cast<RankedTensorType>(getWeight().getType());
  auto outputTy = dyn_cast<RankedTensorType>(getOutput().getType());

  const bool allRanked = inputTy && weightTy && outputTy;
  if (!allRanked)
    return emitError("requires ranked tensor input, weight, and output");

  const bool allRankFour = inputTy.getRank() == 4 && weightTy.getRank() == 4 && outputTy.getRank() == 4;
  if (!allRankFour)
    return emitError("requires rank-4 input, weight, and output tensors");

  if (!isKnownLogicalLayout(getLogicalLayout()))
    return emitError("requires a known logical layout");

  // Fixed-arity convolution parameters.
  if (getPads().size() != 4)
    return emitError("requires exactly four pad values");
  if (getStrides().size() != 2)
    return emitError("requires exactly two stride values");
  if (getDilations().size() != 2)
    return emitError("requires exactly two dilation values");
  if (getGroup() < 1)
    return emitError("requires group >= 1");

  // All three tensors must agree on the element type.
  auto elementTy = inputTy.getElementType();
  const bool sameElementType = elementTy == weightTy.getElementType() && elementTy == outputTy.getElementType();
  if (!sameElementType)
    return emitError("requires matching input, weight, and output element types");

  // The bias is optional; it is only checked when present.
  if (auto bias = getBias()) {
    auto biasTy = dyn_cast<RankedTensorType>(bias.getType());
    if (!biasTy)
      return emitError("requires ranked tensor bias type");
    if (biasTy.getElementType() != outputTy.getElementType())
      return emitError("requires bias element type to match output element type");
  }

  return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
/// Verifies a relu plan op: ranked input/output tensors with matching element
/// types, and a recognised logical layout.
LogicalResult SpatReluPlanOp::verify() {
  LogicalResult typeCheck = verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan");
  if (failed(typeCheck))
    return failure();

  if (isKnownLogicalLayout(getLogicalLayout()))
    return success();
  return emitError("requires a known logical layout");
}
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
LogicalResult SpatReconciliatorOp::verify() {
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (!isKnownPhysicalLayout(getPhysicalLayout()))
return emitError("requires a known physical layout");
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
if (!logicalType)
return emitError("requires ranked tensor output");
auto offsets = getFragmentOffsets();
auto sizes = getFragmentSizes();
if (offsets.size() != sizes.size())
return emitError("fragment offset and size arrays must have the same length");
if (offsets.empty())
return success();
int64_t rank = logicalType.getRank();
if (rank <= 0 || offsets.size() % rank != 0)
return emitError("fragment metadata must be a whole number of rank-sized fragments");
ArrayRef<int64_t> shape = logicalType.getShape();
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
int64_t dim = index % rank;
int64_t offset = offsets[index];
int64_t size = sizes[index];
if (offset < 0 || size < 0)
return emitError("fragment offsets and sizes must be non-negative");
int64_t logicalDim = shape[dim];
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
return emitError("fragment bounds must stay within the logical tensor shape");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
return success();
}
/// Verifies a materialize-layout op: ranked input/output tensors with matching
/// element types, a recognised logical layout, and recognised source and
/// target physical layouts (checked in that order).
LogicalResult SpatMaterializeLayoutOp::verify() {
  LogicalResult typeCheck =
      verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout");
  if (failed(typeCheck))
    return failure();

  const bool logicalOk = isKnownLogicalLayout(getLogicalLayout());
  if (!logicalOk)
    return emitError("requires a known logical layout");

  const bool sourceOk = isKnownPhysicalLayout(getSourcePhysicalLayout());
  if (!sourceOk)
    return emitError("requires a known source physical layout");

  const bool targetOk = isKnownPhysicalLayout(getTargetPhysicalLayout());
  if (!targetOk)
    return emitError("requires a known target physical layout");

  return success();
}
/// Checks that no result of a Spatial compute-like op is consumed by an
/// operation whose immediate parent op is itself a Spatial compute-like op.
///
/// Fails with an error on `op` when `op` is not compute-like, or when such a
/// user exists. Only the direct parent of each user is inspected; enclosing
/// ops further up are not examined here.
LogicalResult verifyComputeResultsUses(Operation* op) {
  if (!isAnySpatialComputeLike(op))
    return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");

  for (Value result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // The first offending user is enough to reject the op.
      if (isAnySpatialComputeLike(user->getParentOp()))
        return op->emitError("compute result used directly inside another Spatial compute body");
    }
  }

  return success();
}
/// Shared verifier for the Spatial compute-like operations.
///
/// \param compute  The operation being verified. It must provide getBody(),
///                 getWeights(), getInputs(), getWeightArgument(),
///                 getInputArgument(), getResultTypes() and emitOpError().
/// \param opName   Operation name forwarded to verifyStaticWeights and
///                 verifyOnlyConstantExternalValues.
/// \returns success() when every check passes; otherwise the first failure.
///
/// Checks, in order:
///   - the body block has one argument per weight and per input operand;
///   - every weight/input block argument has exactly its operand's type;
///   - if the block may have a terminator, it is a SpatYieldOp whose operand
///     types match the op's result types;
///   - every input block argument has at least one use;
///   - the static-weight, constant-external-value and result-use checks pass.
template <typename ComputeOpTy>
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
  auto& block = compute.getBody().front();

  unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
  if (block.getNumArguments() != expectedArgCount)
    return compute.emitOpError("compute body must have weight and input block arguments");

  for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
    auto blockArg = compute.getWeightArgument(weightIndex);
    if (!blockArg || blockArg->getType() != weight.getType())
      return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
  }

  for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
    auto blockArg = compute.getInputArgument(inputIndex);
    if (!blockArg || blockArg->getType() != input.getType())
      return compute.emitOpError("compute input block argument types must match input operand types exactly");
  }

  // The yield checks only run when the block may already carry a terminator.
  if (block.mightHaveTerminator()) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
    if (!yieldOp)
      return compute.emitOpError("ComputeOp must have a single yield operation");

    auto resultTypes = compute.getResultTypes();
    auto yieldTypes = yieldOp->getOperandTypes();
    if (resultTypes.size() != yieldTypes.size())
      return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");

    for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
      auto resultType = std::get<0>(it);
      auto yieldType = std::get<1>(it);
      if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
        return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");

      // Both types are identical past the check above; the encoding
      // comparisons below are kept as a defensive second line.
      if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
        if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
          if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
            return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
        }
        else {
          return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
        }
      }
      else if (dyn_cast<RankedTensorType>(yieldType)) {
        return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
      }
    }
  }

  // An input operand whose block argument is never read is rejected.
  for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
    if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
      return compute.emitOpError("ComputeOp block argument is not used");

  if (failed(verifyStaticWeights(compute, opName)))
    return failure();
  if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
    return failure();
  if (failed(verifyComputeResultsUses(compute.getOperation())))
    return failure();
  return success();
}
/// Verifier for spat.graph_compute; delegates to the shared compute-like
/// verifier with this op's name.
LogicalResult SpatGraphCompute::verify() {
  return verifyComputeLikeOp(*this, "spat.graph_compute");
}
/// Verifier for spat.scheduled_compute; delegates to the shared compute-like
/// verifier with this op's name.
LogicalResult SpatScheduledCompute::verify() {
  StringRef opName = "spat.scheduled_compute";
  return verifyComputeLikeOp(*this, opName);
}
/// Shared verifier for the Spatial compute-batch-like operations.
///
/// \param batch   The operation being verified. It must provide
///                getLaneCount(), getBody(), getWeights(), getInputs(),
///                getNumResults(), getResultTypes(), the lane/weight/input/
///                output block-argument accessors and emitOpError().
/// \param opName  Operation name forwarded to verifyStaticWeights and
///                verifyOnlyConstantExternalValues.
/// \returns success() when every check passes; otherwise the first failure.
///
/// Checks, in order:
///   - laneCount is strictly positive;
///   - if the coreIds attribute is present, it is a dense i32 array with one
///     non-negative, unique entry per lane;
///   - the body block has 1 + #weights + #inputs + #results arguments;
///   - the lane argument has index type;
///   - weight/input/output block arguments match their operand/result types;
///   - the result-use, static-weight and constant-external-value checks pass;
///   - verifyBatchBody accepts the body block.
template <typename ComputeBatchOpTy>
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
  int32_t count = batch.getLaneCount();
  if (count <= 0)
    return batch.emitOpError("laneCount must be positive");
  auto laneCountSz = static_cast<size_t>(count);

  // The coreIds attribute is optional; it is only validated when present.
  if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
    auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
    if (!coreIdsAttr)
      return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
    if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
      return batch.emitOpError("compute_batch coreIds array length must match laneCount");
    if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
      return batch.emitOpError("compute_batch coreIds values must be non-negative");
    DenseSet<int32_t> seenCoreIds;
    for (int32_t coreId : coreIdsAttr.asArrayRef())
      if (!seenCoreIds.insert(coreId).second)
        return batch.emitOpError("compute_batch coreIds values must be unique");
  }

  Block& block = batch.getBody().front();
  if (block.getNumArguments() == 0)
    return batch.emitOpError("compute_batch body must have exactly one lane block argument");
  unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
  if (block.getNumArguments() != expectedArgCount)
    return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");

  auto laneArg = batch.getLaneArgument();
  if (!laneArg || !laneArg->getType().isIndex())
    return batch.emitOpError("compute_batch first block argument must have index type");

  for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
    auto blockArg = batch.getWeightArgument(weightIndex);
    if (!blockArg || blockArg->getType() != weight.getType())
      return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
  }

  for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
    auto blockArg = batch.getInputArgument(inputIndex);
    if (!blockArg || blockArg->getType() != input.getType())
      return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
  }

  for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
    auto blockArg = batch.getOutputArgument(resultIndex);
    if (!blockArg || blockArg->getType() != resultType)
      return batch.emitOpError("compute_batch output block argument types must match result types exactly");
  }

  if (failed(verifyComputeResultsUses(batch.getOperation())))
    return failure();
  if (failed(verifyStaticWeights(batch, opName)))
    return failure();
  if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
    return failure();
  return verifyBatchBody(batch, block);
}
/// Verifier for spat.graph_compute_batch; delegates to the shared
/// compute-batch-like verifier with this op's name.
LogicalResult SpatGraphComputeBatch::verify() {
  StringRef opName = "spat.graph_compute_batch";
  return verifyComputeBatchLikeOp(*this, opName);
}
/// Verifier for spat.scheduled_compute_batch; delegates to the shared
/// compute-batch-like verifier with this op's name.
LogicalResult SpatScheduledComputeBatch::verify() {
  StringRef opName = "spat.scheduled_compute_batch";
  return verifyComputeBatchLikeOp(*this, opName);
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
Operation* parent = getOperation()->getParentOp();
if (!isAnySpatialComputeBatchLike(parent))
return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
if (parent->getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
auto laneArg = batchOp.getLaneArgument();
std::optional<BlockArgument> laneArg;
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
laneArg = graphBatch.getLaneArgument();
else
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
if (!laneArg)
return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) {
@@ -494,7 +652,10 @@ LogicalResult SpatInParallelOp::verify() {
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
if ((isa<SpatGraphComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|| (isa<SpatScheduledComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
File diff suppressed because it is too large Load Diff
@@ -40,10 +40,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
using SpatCompute = spatial::SpatGraphCompute;
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
if (failed(checkedCoreId))
@@ -209,7 +209,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
};
for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds;
@@ -229,7 +229,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
totalCrossbarCount += perInstanceCrossbarCount;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
@@ -353,7 +353,17 @@ public:
void runOnOperation() override {
func::FuncOp func = getOperation();
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(func);
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
signalPassFailure();
return;
}
const spatial::MergeScheduleResult* analysisResult = nullptr;
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
@@ -367,8 +377,8 @@ public:
signalPassFailure();
return;
}
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
if (failed(verifyScheduledSpatialInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
signalPassFailure();
return;
}