Files
Raptor/src/PIM/Dialect/Spatial/SpatialOps.cpp
T
NiccoloN a50e77ff38
Validate Operations / validate-operations (push) Has been cancelled
refactorone
2026-05-20 19:06:41 +02:00

97 lines
3.2 KiB
C++

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string>
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
/// Returns argument `idx` of the body's entry block.
/// Weight arguments occupy the leading positions of that block, ahead of the
/// input arguments.
BlockArgument SpatCompute::getWeightArgument(unsigned idx) {
  Block& bodyBlock = getBody().front();
  return bodyBlock.getArgument(idx);
}
/// Returns the body block argument for input `idx`.
/// Inputs are laid out directly after the weight arguments, so the position
/// is offset by the number of weights.
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
  const size_t firstInputPosition = getWeights().size();
  return getBody().front().getArgument(firstInputPosition + idx);
}
/// Asm-printer hook that gives the body block arguments readable names.
///
/// The entry block is laid out as [weights..., inputs...]. Weights are named
/// `w<i>` and inputs `in<i>`. If the entry block has fewer arguments than this
/// layout requires (e.g. malformed IR), the arguments are left unnamed. This
/// avoids indexing past the end of the argument list.
///
/// @param region    the region being printed; nothing is named if it is empty.
/// @param setNameFn callback that assigns a suggested name to a value.
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
  if (region.empty())
    return;

  const size_t numWeights = getWeights().size();
  const size_t numInputs = getInputs().size();
  // Guard against an entry block whose arity does not match the operand
  // segments; getArgument() would otherwise read out of range.
  if (getBody().front().getNumArguments() < numWeights + numInputs)
    return;

  // The temporary std::string lives until the end of the full-expression, so
  // it stays valid for the duration of the setNameFn call (it converts to
  // StringRef implicitly; no c_str() round-trip is needed).
  for (unsigned index = 0; index < numWeights; ++index)
    setNameFn(getWeightArgument(index), "w" + std::to_string(index));
  for (unsigned index = 0; index < numInputs; ++index)
    setNameFn(getInputArgument(index), "in" + std::to_string(index));
}
/// Returns the lane argument, which is the first argument of the body's
/// entry block.
BlockArgument SpatComputeBatch::getLaneArgument() {
  Block& entryBlock = getBody().front();
  return entryBlock.getArgument(0);
}
/// Returns the body block argument for weight `idx`.
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) {
  // Position 0 holds the lane argument, so weights start at position 1.
  const unsigned position = idx + 1;
  return getBody().front().getArgument(position);
}
/// Returns the body block argument for input `idx`.
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
  // Inputs follow the lane argument and every weight argument.
  const size_t inputsBegin = 1 + getWeights().size();
  return getBody().front().getArgument(inputsBegin + idx);
}
/// Returns the body block argument for output `idx`.
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
  // Outputs come after the lane, weight, and input arguments.
  const size_t outputsBegin = 1 + getWeights().size() + getInputs().size();
  return getBody().front().getArgument(outputsBegin + idx);
}
/// Asm-printer hook that gives the body block arguments readable names.
///
/// The entry block is laid out as [lane, weights..., inputs..., outputs...].
/// The names are:
///   - the lane argument: `lane`
///   - weights: `w<i>`
///   - inputs: `in<i>`
///   - the first output: plain `out`
///   - later outputs: `out<i>`
/// If the entry block has fewer arguments than this layout requires (e.g.
/// malformed IR), the arguments are left unnamed. This avoids indexing past
/// the end of the argument list.
///
/// @param region    the region being printed; nothing is named if it is empty.
/// @param setNameFn callback that assigns a suggested name to a value.
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
  if (region.empty())
    return;

  const size_t numWeights = getWeights().size();
  const size_t numInputs = getInputs().size();
  const size_t numOutputs = getNumResults();
  // Guard against an entry block whose arity does not match the expected
  // layout; getArgument() would otherwise read out of range.
  if (getBody().front().getNumArguments() < 1 + numWeights + numInputs + numOutputs)
    return;

  setNameFn(getLaneArgument(), "lane");
  // Each temporary std::string outlives the setNameFn call (it is destroyed
  // at the end of the full-expression) and converts to StringRef implicitly.
  for (unsigned index = 0; index < numWeights; ++index)
    setNameFn(getWeightArgument(index), "w" + std::to_string(index));
  for (unsigned index = 0; index < numInputs; ++index)
    setNameFn(getInputArgument(index), "in" + std::to_string(index));
  for (unsigned index = 0; index < numOutputs; ++index) {
    // The first output is named plain `out`; later ones carry their index.
    const std::string name = index == 0 ? std::string("out") : "out" + std::to_string(index);
    setNameFn(getOutputArgument(index), name);
  }
}
/// Builds a SpatInParallelOp with one region that contains a single empty
/// block. The insertion guard restores the builder's insertion point once
/// this function returns, undoing the repositioning done by createBlock.
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.createBlock(result.addRegion());
}
/// Returns result `idx` of the operation that directly encloses this op.
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
  Operation* enclosingOp = getOperation()->getParentOp();
  return enclosingOp->getResult(idx);
}
/// Returns, in order, the operations in the first block of this op's region.
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
  Block& body = getRegion().front();
  return llvm::make_range(body.begin(), body.end());
}
/// Registers the Spatial dialect's TableGen-generated types and operations.
/// The lists are expanded from the generated .cpp.inc files via the
/// GET_TYPEDEF_LIST / GET_OP_LIST macros.
void SpatialDialect::initialize() {
// Type classes emitted into SpatialTypes.cpp.inc.
addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
>();
// Operation classes emitted into SpatialOps.cpp.inc.
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
>();
}
} // namespace spatial
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"