This commit is contained in:
@@ -1,10 +1,74 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
if (index == 0) {
|
||||
setNameFn(getOutputArgument(index), "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region* bodyRegion = result.addRegion();
|
||||
builder.createBlock(bodyRegion);
|
||||
}
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
|
||||
return getRegion().front().getOperations();
|
||||
}
|
||||
|
||||
void SpatialDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
|
||||
Reference in New Issue
Block a user