DeadLock
This commit is contained in:
@@ -26,7 +26,7 @@ def SpatTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatCompute : SpatOp<"compute",
|
||||
class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
}
|
||||
|
||||
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
}
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||
Pure,
|
||||
Terminator,
|
||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||
HasParent<"SpatComputeBatch">,
|
||||
] # GraphRegionNoTerminator.traits> {
|
||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
||||
let summary = "Parallel combining terminator for resultful Spatial compute batches";
|
||||
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
@@ -159,6 +192,82 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Planning
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
|
||||
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input,
|
||||
SpatTensor:$weight,
|
||||
Optional<SpatTensor>:$bias,
|
||||
DenseI64ArrayAttr:$pads,
|
||||
DenseI64ArrayAttr:$strides,
|
||||
DenseI64ArrayAttr:$dilations,
|
||||
I64Attr:$group,
|
||||
StrAttr:$logicalLayout
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
|
||||
let summary = "Layout-aware ReLU planning op";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input,
|
||||
StrAttr:$logicalLayout
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
||||
let summary = "Passive logical-to-physical layout selection record";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input,
|
||||
StrAttr:$logicalLayout,
|
||||
StrAttr:$physicalLayout,
|
||||
DenseI64ArrayAttr:$fragmentOffsets,
|
||||
DenseI64ArrayAttr:$fragmentSizes,
|
||||
StrAttr:$indexMap
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
|
||||
let summary = "Explicit layout conversion or materialization barrier";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input,
|
||||
StrAttr:$logicalLayout,
|
||||
StrAttr:$sourcePhysicalLayout,
|
||||
StrAttr:$targetPhysicalLayout
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
|
||||
}
|
||||
|
||||
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
|
||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
|
||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
|
||||
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
}
|
||||
|
||||
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
||||
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
|
||||
return weights;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), getWeights().size() + idx);
|
||||
template <typename ComputeOpTy>
|
||||
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
|
||||
return getBlockArgument(compute.getBody(), idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
return {
|
||||
{*existing, *getWeightArgument(index)}
|
||||
};
|
||||
template <typename ComputeOpTy>
|
||||
std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
|
||||
return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
|
||||
auto index = std::distance(compute.getWeights().begin(), existing);
|
||||
return {{*existing, *getComputeWeightArgument(compute, index)}};
|
||||
}
|
||||
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
unsigned weightCount = compute.getWeights().size();
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
compute.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
|
||||
compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
template <typename ComputeBatchOpTy>
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
|
||||
auto index = std::distance(batch.getWeights().begin(), existing);
|
||||
return {{*existing, *batch.getWeightArgument(index)}};
|
||||
}
|
||||
|
||||
unsigned weightCount = batch.getWeights().size();
|
||||
unsigned inputCount = batch.getInputs().size();
|
||||
batch.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||
batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
|
||||
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
|
||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
||||
newCompute->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||
template <typename ComputeOpTy>
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = compute.getWeights().size();
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
compute.getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
template <typename ComputeOpTy>
|
||||
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
for (unsigned index = 0; index < compute.getWeights().size(); ++index)
|
||||
if (auto weightArg = compute.getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
for (unsigned index = 0; index < compute.getInputs().size(); ++index)
|
||||
if (auto inputArg = compute.getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||
template <typename ComputeOpTy>
|
||||
FailureOr<std::tuple<OpResult, ComputeOpTy>>
|
||||
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > compute.getNumResults())
|
||||
return failure();
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||
rewriter.setInsertionPoint(compute.getOperation());
|
||||
SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newCompute =
|
||||
ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
|
||||
newCompute->setAttrs(compute->getAttrs());
|
||||
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx)
|
||||
compute.getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(compute.getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||
}
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
|
||||
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > batch.getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(batch.getOperation());
|
||||
SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newBatch =
|
||||
ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
|
||||
newBatch->setAttrs(batch->getAttrs());
|
||||
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||
if (newBatch.getBody().empty()) {
|
||||
rewriter.eraseOp(newBatch);
|
||||
return failure();
|
||||
}
|
||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx)
|
||||
batch.getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(batch.getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isGraphComputeLike(Operation* op) { return isa<SpatGraphCompute, SpatGraphComputeBatch>(op); }
|
||||
|
||||
bool isGraphBatchComputeLike(Operation* op) { return isa<SpatGraphComputeBatch>(op); }
|
||||
|
||||
bool isScheduledComputeLike(Operation* op) { return isa<SpatScheduledCompute, SpatScheduledComputeBatch>(op); }
|
||||
|
||||
bool isScheduledBatchComputeLike(Operation* op) { return isa<SpatScheduledComputeBatch>(op); }
|
||||
|
||||
bool isAnySpatialComputeLike(Operation* op) {
|
||||
return isa<SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute, SpatScheduledComputeBatch>(op);
|
||||
}
|
||||
|
||||
bool isAnySpatialComputeBatchLike(Operation* op) { return isa<SpatGraphComputeBatch, SpatScheduledComputeBatch>(op); }
|
||||
|
||||
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
|
||||
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
return insertComputeWeight(*this, idx, weight, loc);
|
||||
}
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||
return insertComputeInput(*this, idx, input, loc);
|
||||
}
|
||||
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
|
||||
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||
}
|
||||
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
|
||||
return getComputeWeightArgument(*this, idx);
|
||||
}
|
||||
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
return insertComputeWeight(*this, idx, weight, loc);
|
||||
}
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||
return insertComputeInput(*this, idx, input, loc);
|
||||
}
|
||||
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
|
||||
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||
}
|
||||
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
return {
|
||||
{*existing, *getWeightArgument(index)}
|
||||
};
|
||||
}
|
||||
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newBatch =
|
||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
||||
newBatch->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||
if (newBatch.getBody().empty()) {
|
||||
rewriter.eraseOp(newBatch);
|
||||
return failure();
|
||||
}
|
||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
|
||||
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
if (auto laneArg = getLaneArgument())
|
||||
setNameFn(*laneArg, "lane");
|
||||
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
continue;
|
||||
if (index == 0) {
|
||||
setNameFn(*outputArg, "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||
std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + idx);
|
||||
}
|
||||
std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
}
|
||||
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||
}
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
|
||||
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||
}
|
||||
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
if (auto laneArg = getLaneArgument())
|
||||
setNameFn(*laneArg, "lane");
|
||||
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
builder.createBlock(bodyRegion);
|
||||
}
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
|
||||
Operation* parent = getOperation()->getParentOp();
|
||||
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent");
|
||||
return parent->getResult(idx);
|
||||
}
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||
|
||||
|
||||
@@ -26,3 +26,19 @@
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
bool isGraphComputeLike(mlir::Operation* op);
|
||||
bool isGraphBatchComputeLike(mlir::Operation* op);
|
||||
bool isScheduledComputeLike(mlir::Operation* op);
|
||||
bool isScheduledBatchComputeLike(mlir::Operation* op);
|
||||
bool isAnySpatialComputeLike(mlir::Operation* op);
|
||||
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
|
||||
|
||||
using SpatCompute = SpatGraphCompute;
|
||||
using SpatComputeBatch = SpatGraphComputeBatch;
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(op.getWeights().size());
|
||||
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||
auto weightArg = op.getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(op.getInputs().size());
|
||||
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||
auto inputArg = op.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
|
||||
|
||||
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreId)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
|
||||
auto laneArg = op.getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(op.getWeights().size());
|
||||
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||
auto weightArg = op.getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(op.getInputs().size());
|
||||
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||
auto inputArg = op.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
if (op.getNumResults() != 0) {
|
||||
outputArgs.reserve(op.getNumResults());
|
||||
for (unsigned index = 0; index < op.getNumResults(); ++index) {
|
||||
auto outputArg = op.getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << op.getLaneCount();
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||
if (op.getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
|
||||
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreId)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
auto laneArg = getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
if (getNumResults() != 0) {
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
|
||||
}
|
||||
|
||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||
|
||||
@@ -10,8 +10,9 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
template <typename ComputeOpTy>
|
||||
LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = compute.getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArg.getOwner() == &block) {
|
||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
||||
results.push_back(compute.getOperand(blockArg.getArgNumber()));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
return foldComputeLike(*this, results);
|
||||
}
|
||||
|
||||
LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
return foldComputeLike(*this, results);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
template <typename ComputeBatchOpTy>
|
||||
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return false;
|
||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isStaticIndexExpr(Value value) {
|
||||
if (matchConstantIndexValue(value))
|
||||
return true;
|
||||
|
||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||
if (affineApply) {
|
||||
if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
|
||||
return false;
|
||||
return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr);
|
||||
}
|
||||
|
||||
if (auto addOp = value.getDefiningOp<arith::AddIOp>())
|
||||
return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs());
|
||||
|
||||
if (auto mulOp = value.getDefiningOp<arith::MulIOp>())
|
||||
return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
if (value == laneArg || matchConstantIndexValue(value))
|
||||
if (value == laneArg || isStaticIndexExpr(value))
|
||||
return true;
|
||||
|
||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
}
|
||||
|
||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||
if (!addOp)
|
||||
if (addOp)
|
||||
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|
||||
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
|
||||
|
||||
auto mulOp = value.getDefiningOp<arith::MulIOp>();
|
||||
if (!mulOp)
|
||||
return false;
|
||||
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|
||||
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
|
||||
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|
||||
|| (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
InFlightDiagnostic diagnostic =
|
||||
ownerOp->emitOpError() << kind << " body may not capture external values";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
<< "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
|
||||
<< operand.getOperandNumber() << " type=" << value.getType()
|
||||
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
|
||||
: "scalar"));
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
|
||||
else if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
|
||||
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
|
||||
<< blockArg.getOwner()->getParentOp()->getName() << "'";
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
template <typename ComputeBatchOpTy>
|
||||
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
|
||||
if (batchOp.getNumResults() == 0) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
@@ -344,144 +380,266 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute");
|
||||
static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
|
||||
|
||||
static bool isKnownPhysicalLayout(StringRef layout) {
|
||||
return layout == "dense_nchw" || layout == "nchw_row_strip";
|
||||
}
|
||||
|
||||
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
|
||||
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||
auto outputType = dyn_cast<RankedTensorType>(output.getType());
|
||||
if (!inputType || !outputType)
|
||||
return op->emitOpError() << kind << " requires ranked tensor input and output types";
|
||||
if (inputType.getElementType() != outputType.getElementType())
|
||||
return op->emitOpError() << kind << " requires matching input/output element types";
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatConv2DPlanOp::verify() {
|
||||
auto inputType = dyn_cast<RankedTensorType>(getInput().getType());
|
||||
auto weightType = dyn_cast<RankedTensorType>(getWeight().getType());
|
||||
auto outputType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||
if (!inputType || !weightType || !outputType)
|
||||
return emitError("requires ranked tensor input, weight, and output");
|
||||
if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4)
|
||||
return emitError("requires rank-4 input, weight, and output tensors");
|
||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||
return emitError("requires a known logical layout");
|
||||
if (getPads().size() != 4)
|
||||
return emitError("requires exactly four pad values");
|
||||
if (getStrides().size() != 2)
|
||||
return emitError("requires exactly two stride values");
|
||||
if (getDilations().size() != 2)
|
||||
return emitError("requires exactly two dilation values");
|
||||
if (getGroup() < 1)
|
||||
return emitError("requires group >= 1");
|
||||
if (inputType.getElementType() != weightType.getElementType()
|
||||
|| inputType.getElementType() != outputType.getElementType()) {
|
||||
return emitError("requires matching input, weight, and output element types");
|
||||
}
|
||||
if (getBias()) {
|
||||
auto biasType = dyn_cast<RankedTensorType>(getBias().getType());
|
||||
if (!biasType)
|
||||
return emitError("requires ranked tensor bias type");
|
||||
if (biasType.getElementType() != outputType.getElementType())
|
||||
return emitError("requires bias element type to match output element type");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
LogicalResult SpatReluPlanOp::verify() {
|
||||
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan")))
|
||||
return failure();
|
||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||
return emitError("requires a known logical layout");
|
||||
return success();
|
||||
}
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
LogicalResult SpatReconciliatorOp::verify() {
|
||||
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
|
||||
return failure();
|
||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||
return emitError("requires a known logical layout");
|
||||
if (!isKnownPhysicalLayout(getPhysicalLayout()))
|
||||
return emitError("requires a known physical layout");
|
||||
|
||||
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||
if (!logicalType)
|
||||
return emitError("requires ranked tensor output");
|
||||
|
||||
auto offsets = getFragmentOffsets();
|
||||
auto sizes = getFragmentSizes();
|
||||
if (offsets.size() != sizes.size())
|
||||
return emitError("fragment offset and size arrays must have the same length");
|
||||
if (offsets.empty())
|
||||
return success();
|
||||
|
||||
int64_t rank = logicalType.getRank();
|
||||
if (rank <= 0 || offsets.size() % rank != 0)
|
||||
return emitError("fragment metadata must be a whole number of rank-sized fragments");
|
||||
|
||||
ArrayRef<int64_t> shape = logicalType.getShape();
|
||||
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
|
||||
int64_t dim = index % rank;
|
||||
int64_t offset = offsets[index];
|
||||
int64_t size = sizes[index];
|
||||
if (offset < 0 || size < 0)
|
||||
return emitError("fragment offsets and sizes must be non-negative");
|
||||
int64_t logicalDim = shape[dim];
|
||||
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
|
||||
return emitError("fragment bounds must stay within the logical tensor shape");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatMaterializeLayoutOp::verify() {
|
||||
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout")))
|
||||
return failure();
|
||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||
return emitError("requires a known logical layout");
|
||||
if (!isKnownPhysicalLayout(getSourcePhysicalLayout()))
|
||||
return emitError("requires a known source physical layout");
|
||||
if (!isKnownPhysicalLayout(getTargetPhysicalLayout()))
|
||||
return emitError("requires a known target physical layout");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isAnySpatialComputeLike(op))
|
||||
return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !isAnySpatialComputeLike(op->getParentOp());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("compute result used directly inside another Spatial compute body");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
|
||||
auto& block = compute.getBody().front();
|
||||
unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return compute.emitOpError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto blockArg = compute.getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto blockArg = compute.getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
return compute.emitOpError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("ComputeOp must have a single yield operation");
|
||||
return compute.emitOpError("ComputeOp must have a single yield operation");
|
||||
|
||||
auto resultTypes = getResultTypes();
|
||||
auto resultTypes = compute.getResultTypes();
|
||||
auto yieldTypes = yieldOp->getOperandTypes();
|
||||
if (resultTypes.size() != yieldTypes.size())
|
||||
return emitError("ComputeOp must have same number of results as yieldOp operands");
|
||||
return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");
|
||||
|
||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||
auto resultType = std::get<0>(it);
|
||||
auto yieldType = std::get<1>(it);
|
||||
|
||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
||||
return emitError("ComputeOp output must be of the same type as yieldOp operand");
|
||||
return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");
|
||||
|
||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
||||
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
|
||||
return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
|
||||
}
|
||||
else {
|
||||
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
||||
return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
||||
}
|
||||
}
|
||||
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
||||
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
||||
return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyStaticWeights(*this, "compute")))
|
||||
for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
|
||||
if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return compute.emitOpError("ComputeOp block argument is not used");
|
||||
if (failed(verifyStaticWeights(compute, opName)))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
if (failed(verifyComputeResultsUses(compute.getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
int32_t count = getLaneCount();
|
||||
LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); }
|
||||
|
||||
LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); }
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
|
||||
int32_t count = batch.getLaneCount();
|
||||
if (count <= 0)
|
||||
return emitError("laneCount must be positive");
|
||||
return batch.emitOpError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||
if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
if (!coreIdsAttr)
|
||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
return batch.emitOpError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
return batch.emitOpError("compute_batch coreIds values must be non-negative");
|
||||
DenseSet<int32_t> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
return emitError("compute_batch coreIds values must be unique");
|
||||
return batch.emitOpError("compute_batch coreIds values must be unique");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
Block& block = batch.getBody().front();
|
||||
if (block.getNumArguments() == 0)
|
||||
return emitError("compute_batch body must have exactly one lane block argument");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
return batch.emitOpError("compute_batch body must have exactly one lane block argument");
|
||||
unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||
auto laneArg = getLaneArgument();
|
||||
return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
if (!laneArg || !laneArg->getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
return batch.emitOpError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
|
||||
auto blockArg = batch.getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
|
||||
auto blockArg = batch.getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
auto blockArg = getOutputArgument(resultIndex);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
|
||||
auto blockArg = batch.getOutputArgument(resultIndex);
|
||||
if (!blockArg || blockArg->getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
return batch.emitOpError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
if (failed(verifyComputeResultsUses(batch.getOperation())))
|
||||
return failure();
|
||||
if (failed(verifyStaticWeights(*this, "compute_batch")))
|
||||
if (failed(verifyStaticWeights(batch, opName)))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
return verifyBatchBody(batch, block);
|
||||
}
|
||||
|
||||
LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); }
|
||||
|
||||
LogicalResult SpatScheduledComputeBatch::verify() {
|
||||
return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatInParallelOp::verify() {
|
||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return emitOpError("expected spat.compute_batch parent");
|
||||
if (batchOp.getNumResults() == 0)
|
||||
Operation* parent = getOperation()->getParentOp();
|
||||
if (!isAnySpatialComputeBatchLike(parent))
|
||||
return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
|
||||
if (parent->getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
std::optional<BlockArgument> laneArg;
|
||||
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
|
||||
laneArg = graphBatch.getLaneArgument();
|
||||
else
|
||||
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
|
||||
if (!laneArg)
|
||||
return emitOpError("expected compute_batch lane block argument");
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
@@ -494,7 +652,10 @@ LogicalResult SpatInParallelOp::verify() {
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
for (OpOperand& destination : destinations)
|
||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
||||
if ((isa<SpatGraphComputeBatch>(parent)
|
||||
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|
||||
|| (isa<SpatScheduledComputeBatch>(parent)
|
||||
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
|
||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||
}
|
||||
|
||||
|
||||
+4645
-263
File diff suppressed because it is too large
Load Diff
@@ -40,10 +40,10 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
using SpatCompute = spatial::SpatGraphCompute;
|
||||
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
@@ -209,7 +209,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
};
|
||||
|
||||
for (Operation& op : funcOp.getBody().front()) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
@@ -229,7 +229,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
totalCrossbarCount += perInstanceCrossbarCount;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||
@@ -353,7 +353,17 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
mergeTriviallyConnectedComputes(func);
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||
@@ -367,8 +377,8 @@ public:
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
||||
if (failed(verifyScheduledSpatialInvariants(func))) {
|
||||
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user