# Data-Tiling + Multi-Device
## Overview
Prototype branch (based on main branch): https://github.com/iree-org/iree/pull/18738
IREE is designed first and foremost as a cross-compiler with multi-targeting at its core. This design doc describes how to run data-tiling + multi-device in IREE. It involves changes in several components, including the Encoding, Stream, HAL, and IREECPU/GPU dialects. The key idea is to encode the target information in the encodings and expose it through interfaces. Using interfaces lets us layer the components better and makes encodings more modular: backends can implement encodings through interface methods, and backends can also discard encodings when they do not care about them.
"Multi-device" here means several things and all are
accomplished with the same mechanism: a single device that may be one of
multiple types (multiple CPU/GPU archs, CPU _or_ GPU, etc), multiple
homogeneous devices (4 of the same exact GPUs accessed through the same
runtime HAL driver), multiple heterogeneous devices (a CPU and a
GPU/NPU/etc), and optional devices (a CPU with some portions offloaded
to a GPU/NPU if it's compatible and available at runtime). See the [writeup](https://github.com/iree-org/iree/commit/d39c3c56682e006e842b32aa6f38c272f77c8f3c) for more details.
The prototype scopes the work to heterogeneous computing. It does not support the case where a device has several targets (e.g., `device = [#device_a, #device_b]`). The outcomes are:
1. It provides a mechanism to keep the host and device consistent w.r.t. encodings.
2. The host and device can agree on the total storage sizes through the new interface.
3. Stream is able to specialize encodings for host code and device code.
4. We'll be able to query target encoding layouts from interface methods. I.e., we'll be able to materialize encodings into layouts for **device_a** and execute the relayout ops on **device_b**.
5. It unblocks the multi-device project from data-tiling workarounds.
6. Furthermore, once we move the core materialization logic to the attributes that implement the interfaces, we should be able to run an analysis and materialize the initializers/global variables with encodings at compile time.
Below is the dependency graph in the [prototype](https://github.com/iree-org/iree/pull/18738). The red lines are new dependencies and the black lines are existing dependencies. It mostly introduces interface methods in the Encoding dialect, and the other components implement passes/methods to propagate/handle encodings.
```mermaid
flowchart TD
StreamIR(Stream/IRs) --> Utils[iree/compiler/Util]
StreamIR --> Encoding(Encoding/IR)
StreamX[Stream/Transforms] --> StreamIR
StreamX --> FlowIR[Flow/IR]
StreamX --> HALIR[HAL/IR]
HALIR --> StreamIR
HALIR --> Utils
HALIR --> Encoding
HALX[HAL/Transforms] --> HALIR
HALX --> StreamIR
HALX --> Utils
HALX --> Codegen(IREECodegenDialect)
Codegen --> Encoding
Utils --> Encoding
Targets(plugins/target) --> Encoding
Targets --> Codegen
HALX --> Encoding
linkStyle 1,7,12,13,16 stroke:#F00,color:red;
```
Requirements:
1. Consistency between host/device.
2. Specialization for unique targets (not all possible).
3. Do not add additional dependencies between IREE core dialects, except the Encoding dialect. All the dialects can depend on the Encoding dialect.
4. Demonstrate an e2e example for running data-tiling + heterogeneous computing in IREE.
5. etc.
Note: there is a hidden dependency from `Stream/IR` to `HAL/IR`, which does not show up in the diagram because it goes through a dialect interface. We'll talk about it in the detailed design.
## Terminology
Note that **Device != Executable**. Devices are what you would think (NVIDIA GPU 2 in your machine, CPU NUMA node 0 with a thread pool, etc). Executables are our compiled dispatch code that runs on those devices (ELFs, SPIR-V blobs, VMVX VMFBs, etc) and that is produced by a target backend (llvm-cpu, vmvx, cuda, etc).
Some stream operations have an optional `#stream.affinity` attribute. Those operations implement the `AffinityOpInterface` interface methods. The interface denotes a stream affinity for an op and specifies the kind of environment the op is expected to run in.
E.g.,
```mlir=
%18 = stream.async.dispatch on(#hal.device.affinity<@device_a>)
@foo_dispatch_2::@foo_dispatch_2_matmul_DxDxD_f32
[%1, %7, %0, %8](%14[%c0 to %13 for %13], %16[%c0 to %15 for %15], %1, %7, %0, %8)
: (!stream.resource<*>{%13}, !stream.resource<*>{%15}, index, index, index, index)
-> !stream.resource<*>{%17}
```
- Execution Affinity: The device that the dispatch will be executed on. In the above example, the execution affinity is `@device_a`.
- Resource Affinity: The device that produces the resource (i.e., the input operands of the dispatch).
Some stream ops are Stream tensor ops (e.g., `stream.tensor.sizeof`). They all have the `IREE::Stream::TensorPhaseOp` op trait.
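To make the two affinity terms concrete, below is a minimal C++ sketch. `DispatchSite` and `crossDeviceOperands` are hypothetical names, not IREE data structures, and affinities are modeled as plain strings such as `"@device_a"`. The helper flags the operands whose resource affinity differs from the execution affinity, i.e., the inputs that may have been produced on another device.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified model of one stream.async.dispatch site.
struct DispatchSite {
  std::string entryPoint;                      // e.g. "@foo_dispatch_2::@..."
  std::string executionAffinity;               // device the dispatch runs on
  std::vector<std::string> resourceAffinities; // producer device per operand
};

// Returns the indices of resource operands produced on a device other than
// the one the dispatch executes on.
std::vector<std::size_t> crossDeviceOperands(const DispatchSite &site) {
  std::vector<std::size_t> result;
  for (std::size_t i = 0; i < site.resourceAffinities.size(); ++i) {
    if (site.resourceAffinities[i] != site.executionAffinity)
      result.push_back(i);
  }
  return result;
}
```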
## Design Details
There are a few parts in this section:
- Encoding specialization, which is the most critical part of the doc.
- Encoding host/device tensors, which includes how we allocate buffers in the host code.
- How to codegen encodings with the target information.
- An end-to-end demo.
### Encoding Specialization Idea
The Stream dialect is designed to model execution partitioning and scheduling. At this phase, we know the execution affinities and resource affinities for each `stream.async.dispatch` op; we also know the affinities of Stream tensor ops. This means we have enough information to specialize the encodings, because the affinities can be queried from Stream/HAL analysis.
The core idea is to add a field to the encoding attribute that carries the target information, and to update the encodings in host code and device code (i.e., executables). They need to be updated together in the SpecializeEncodingsPass, because we can't leave the IR in an inconsistent state where the host and device don't agree on the encodings. The sections below describe what is added in the pass, and the new information that codegen backends could have.
#### Additional Field in EncodingAttr
The proposal is to add an optional array of targets, i.e., a list of attributes, to the [EncodingAttr](https://github.com/iree-org/iree/blob/fa6aa1c9489df50b232c7339ec09860780971e03/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td#L60-L68). For the Stream tensor ops, we can gather the execution affinity and update the encodings. E.g.,
```mlir=
// Before specialization:
#encoding = #iree_encoding.encoding<
operand_index = 0 : index,
op_type = matmul,
element_types = [f32, f32, f32],
user_indexing_maps = [#map, #map1, #map2]>
%14 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>)
tensor<?x?xf32, #encoding>{%0, %1} : index
// After specialization:
#new_encoding = #iree_encoding.encoding<
operand_index = 0 : index,
op_type = matmul,
element_types = [f32, f32, f32],
user_indexing_maps = [#map, #map1, #map2],
targets = [whatever_attribute_related_to_device_a]>
%14 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>)
tensor<?x?xf32, #new_encoding>{%0, %1} : index
```
For the dispatches, we can gather the execution affinity and the resource affinities of each `stream.async.dispatch`, clone the executables, and update the dispatch ops. This is discussed in the later sections.
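The Stream-tensor-op half of this update can be sketched in plain C++ as below. This is a hypothetical model, not the EncodingAttr API. The encoding is a struct with an optional `targets` list, and the affinity-to-target lookup is a simple map (`TargetTable`) standing in for the Stream/HAL analysis. Specialization returns a new encoding instead of mutating the old one, in the same spirit as MLIR attributes being immutable.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for #iree_encoding.encoding<...>.
struct Encoding {
  int operandIndex;
  std::string opType;
  // Empty until specialization; mirrors the proposed optional `targets` field.
  std::optional<std::vector<std::string>> targets;
};

// Hypothetical stand-in for the affinity -> target attributes resolution.
using TargetTable = std::map<std::string, std::vector<std::string>>;

// Returns a copy of `encoding` whose `targets` carry the attributes associated
// with `executionAffinity`; unknown affinities leave `targets` unset.
Encoding specializeForAffinity(const Encoding &encoding,
                               const std::string &executionAffinity,
                               const TargetTable &table) {
  Encoding result = encoding;
  auto it = table.find(executionAffinity);
  if (it != table.end())
    result.targets = it->second;
  return result;
}
```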
#### EncodingSolverInterfaceAttr
So the problem becomes: what information is needed in the optional `targets` field? Our goals are (a) not adding new dependencies between core dialects and (b) letting backends provide the information for an encoding (e.g., `calculateStorageElementCountInBytes`). Thus, we introduce the `EncodingSolverInterfaceAttr` attribute interface and attach attributes implementing it to the `targets` field of the encoding. If any attribute in `targets` does not implement the interface, we drop the encoding.
Note again that all of this happens in the SpecializeEncoding pass, so the host code and device code stay consistent.
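The rule "drop the encoding if any target lacks the interface" can be sketched with a C++ class hierarchy, where `dynamic_cast` plays the role of checking whether an attribute implements the interface. All names here (`Attribute`, `EncodingSolver`, `collectSolvers`, `IdentitySolver`, `OpaqueAttr`) are hypothetical. In MLIR, the check would be an interface cast (e.g., `dyn_cast`) on the attribute instead.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

// Base class for any attribute that may appear in `targets`.
struct Attribute {
  virtual ~Attribute() = default;
};

// Hypothetical C++ analogue of EncodingSolverInterfaceAttr.
struct EncodingSolver : Attribute {
  virtual std::int64_t calculateStorageBytes(
      const std::vector<std::int64_t> &shape, std::int64_t elementBytes) const = 0;
};

// A solver without padding: storage is the plain dense size.
struct IdentitySolver : EncodingSolver {
  std::int64_t calculateStorageBytes(const std::vector<std::int64_t> &shape,
                                     std::int64_t elementBytes) const override {
    std::int64_t bytes = elementBytes;
    for (std::int64_t d : shape)
      bytes *= d;
    return bytes;
  }
};

// An attribute that knows nothing about encodings.
struct OpaqueAttr : Attribute {};

// Returns the solvers when every target implements the interface; returns
// std::nullopt (meaning "drop the encoding") as soon as one does not.
std::optional<std::vector<const EncodingSolver *>>
collectSolvers(const std::vector<std::shared_ptr<Attribute>> &targets) {
  std::vector<const EncodingSolver *> solvers;
  for (const auto &attr : targets) {
    auto *solver = dynamic_cast<const EncodingSolver *>(attr.get());
    if (!solver)
      return std::nullopt;
    solvers.push_back(solver);
  }
  return solvers;
}
```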
Below is an example that helps resolve the storage size calculation.
```tablegen=
def IREEEncoding_EncodingSolverInterfaceAttr :
AttrInterface<"EncodingSolverInterfaceAttr"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
let description = [{
Interface used to resolve encoding data from backends.
}];
let methods = [
InterfaceMethod<
/*desc=*/[{
Returns the storage size in bytes of the tensor type with the given dynamic dimensions.
}],
/*retTy=*/"::mlir::OpFoldResult",
/*methodName=*/"calculateStorageElementCountInBytes",
/*args=*/(ins
"::mlir::OpBuilder &":$builder,
"RankedTensorType":$type,
"ValueRange":$dynamicDims
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
}]
>
];
}
```
It is critical for encoding host tensors like `stream.tensor.sizeof` ops. A `stream.tensor.sizeof` op takes a tensor type (with an optional encoding) and dynamic dimensions (if present).
```mlir=
%sizeof = stream.tensor.sizeof on(#hal.device.affinity<@device_a>)
tensor<?x?xf32, #encoding>{%0, %1} : index
```
We introduce the `calculateStorageElementCountInBytes` interface method, and codegen can implement it based on the target configuration (e.g., see the [prototype](https://github.com/hanhanW/iree/blob/a07be105ea12302208aae45c51a7bd246ccbe3a7/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/IREECPUAttrs.td#L18-L52)). We attach the encoding solver to the configuration field of `IREE::HAL::ExecutableTargetAttr`, which carries additional information when we create the target. It is defined by the `plugins/target/*` implementations. The claim is that a target can add the attribute to the dictionary attribute when it has an encoding implementation. Similar things already happen in [ROCM targets](https://github.com/iree-org/iree/blob/d1a991cd98f4d6d32f459b6cb205e4ff89011997/compiler/plugins/target/ROCM/ROCMTarget.cpp#L246-L249): the IREEGPU attribute is populated and attached when IREE creates the executable target attribute.
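As an illustration of what a backend might compute in `calculateStorageElementCountInBytes`, the sketch below assumes a layout that pads every dimension up to a full inner tile. `TiledLayoutSolver` and its tile sizes are hypothetical. Real solvers would derive tiles from the target configuration (e.g., CPU features) and may compute sizes differently. The method name follows the interface above and the result is in bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Rounds `value` up to the next multiple of `multiple` (multiple > 0).
std::int64_t roundUp(std::int64_t value, std::int64_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

// Hypothetical solver for a layout that pads every dimension to full inner
// tiles. The tile sizes stand in for values derived from the target config.
struct TiledLayoutSolver {
  std::vector<std::int64_t> innerTileSizes; // one entry per tensor dimension

  // Storage size in bytes of a tensor with `dims` after padding to tiles.
  std::int64_t calculateStorageElementCountInBytes(
      const std::vector<std::int64_t> &dims, std::int64_t elementBytes) const {
    std::int64_t count = 1;
    for (std::size_t i = 0; i < dims.size(); ++i)
      count *= roundUp(dims[i], innerTileSizes[i]);
    return count * elementBytes;
  }
};
```

For a 2x3 f32 tensor, 16x1 tiles give 16 * 3 * 4 = 192 bytes, while 8x4 tiles give 8 * 4 * 4 = 128 bytes. The same logical tensor can need different storage on different targets, which is why host and device must agree on the size.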
#### Stream Dialect Interface (Hidden Dependency)
As mentioned, we can gather the device affinity (which is a list of `hal.device` attributes) at the Stream level. However, the tricky part is that Stream does not know anything about HAL IR, because the Stream dialect does not depend on the HAL dialect. This conversion (i.e., converting a device to a list of EncodingSolver attributes) is done through a dialect interface. The prototype defines `AffinityAnalysisDialectInterface` in Stream, and HAL [implements the method using DeviceAnalysis](https://github.com/hanhanW/iree/blob/a07be105ea12302208aae45c51a7bd246ccbe3a7/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp#L122-L167).
In the SpecializeEncoding pass, we can iterate over the registered dialects and check whether any of them implements the dialect interface. See the [prototype](https://github.com/hanhanW/iree/blob/a07be105ea12302208aae45c51a7bd246ccbe3a7/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp#L48-L69) if you're interested in what the code looks like.
With the mechanism, we're able to update all the encoding types for Stream tensor ops without adding a dependency from Stream to HAL.
*Side note: It might be confusing why the Stream dialect sees device affinities but does not depend on the HAL dialect. The reason is that they are all just attributes to the Stream dialect, and we rely on interfaces to do the analysis. So the Stream dialect does not need to explicitly depend on the HAL dialect.*
```cpp=
namespace mlir::iree_compiler::IREE::Stream {
class AffinityAnalysisDialectInterface
: public DialectInterface::Base<AffinityAnalysisDialectInterface> {
public:
AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {}
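// Returns a callback that resolves an affinity (in the context of an op)
// into a set of target attributes, e.g., encoding solvers.
virtual std::function<LogicalResult(AffinityAttr, Operation *,
SetVector<Attribute> &)>
makeTargetResolver(ModuleOp moduleOp) const = 0;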
};
} // namespace mlir::iree_compiler::IREE::Stream
```
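The layering trick can be shown without MLIR. In the sketch below, the "Stream" code only sees an abstract interface that hands out a resolver callback, while the "HAL" code implements that interface with its own device table. This is a simplified, hypothetical C++ model, not the prototype's code: affinities are strings, `bool` replaces `LogicalResult`, and there is no `ModuleOp`.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical resolver: fills `targets` for an affinity, false on failure.
using TargetResolver =
    std::function<bool(const std::string &, std::vector<std::string> &)>;

// "Stream" side: only this abstract interface is visible.
struct AffinityAnalysisInterface {
  virtual ~AffinityAnalysisInterface() = default;
  virtual TargetResolver makeTargetResolver() const = 0;
};

// "HAL" side: implements the interface with its own device -> targets table.
struct HALAffinityAnalysis : AffinityAnalysisInterface {
  std::map<std::string, std::vector<std::string>> deviceTargets;

  TargetResolver makeTargetResolver() const override {
    auto table = deviceTargets; // the resolver owns a copy of the table
    return [table](const std::string &affinity,
                   std::vector<std::string> &targets) {
      auto it = table.find(affinity);
      if (it == table.end())
        return false;
      targets.insert(targets.end(), it->second.begin(), it->second.end());
      return true;
    };
  }
};

// "Stream" pass: asks the registered implementations for a resolver.
TargetResolver findResolver(
    const std::vector<std::unique_ptr<AffinityAnalysisInterface>> &registered) {
  for (const auto &iface : registered)
    if (iface)
      return iface->makeTargetResolver();
  return nullptr;
}
```

`findResolver` compiles without any knowledge of `HALAffinityAnalysis`. That is the "hidden" dependency: it exists at runtime through registration, not in the build graph.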
#### Duplicate Executables/Exports
After we gather the encoding solvers and encode them in the encodings (for host code), we should also update the executables and the `stream.async.dispatch` ops. The snippet below shows what the IR looks like before the encoding specialization.
As shown in the IR, there are two `stream.async.dispatch` ops. They have the same entry point, but they run on different devices. The dispatch does not have enough information to generate the relayout code. Also, they should result in different executables/functions; otherwise, we could produce device_a layouts on device_b.
```mlir=
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_solver = #iree_cpu.vmvx_encoding_solver<target_configuration = {ukernels = "none"}>}>
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#device_target_local = #hal.device.target<"local", [#executable_target_vmvx_bytecode_fb]> : !hal.device
#device_target_local1 = #hal.device.target<"local", [#executable_target_embedded_elf_x86_64_]> : !hal.device
stream.executable private @foo_dispatch_0 {
stream.executable.export public @foo_dispatch_0_set_encoding_LHS_DxD workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0, %arg1
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @foo_dispatch_0_set_encoding_LHS_DxD(%arg0: !stream.binding, %arg1: index, %arg2: index, %arg3: !stream.binding) {
%c0 = arith.constant 0 : index
%0 = flow.dispatch.workload.ordinal %arg1, 0 : index
%1 = flow.dispatch.workload.ordinal %arg2, 1 : index
%2 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%0, %1}
%3 = stream.binding.subspan %arg3[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>>>{%0, %1}
%4 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32>>{%0, %1} -> tensor<?x?xf32>
%5 = iree_encoding.set_encoding %4 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>>
flow.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [%0, %1], strides = [1, 1] : tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>>>{%0, %1}
return
}
}
}
// ...
util.func public @foo(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @foo(%input0: tensor<?x?xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}, %input1: tensor<?x?xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<?x?xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
// ...
%14 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @foo_dispatch_0::@foo_dispatch_0_set_encoding_LHS_DxD[%0, %1](%6[%c0 to %2 for %2], %0, %1) : (!stream.resource<*>{%2}, index, index) -> !stream.resource<*>{%13}
// ...
%25 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @foo_dispatch_0::@foo_dispatch_0_set_encoding_LHS_DxD[%0, %1](%22[%c0 to %2 for %2], %0, %1) : (!stream.resource<*>{%2}, index, index) -> !stream.resource<*>{%24}
// ....
}
```
The proposal is to clone the entire executable for each variant and update the entry points of the `stream.async.dispatch` ops. A variant is a unique combination of (ExecutableExportOp, execution affinity, array of resource affinities), so we need a map to track the variants.
Then, for each unique variant, we clone the executable and update the relevant `stream.async.dispatch` ops.
```cpp=
// export -> [affinity -> array per resource of affinities PVS]
DenseMap<Stream::ExecutableExportOp,
SetVector<std::pair<AffinityAttr, ArrayAttr>>>
exportDispatchSites;
```
*Note: Take a look at the [commit](https://github.com/hanhanW/iree/commit/4f838f57666f6ac81789fc7e8a503374ab4c7737) if you're interested in the implementation. Note that the implementation is mostly a prototype; the final code could change when we start landing it in the main branch. We should be able to just clone the export ops and func ops, but the backends are not able to handle that case today, so the prototype clones the entire executables for now.*
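The variant bookkeeping can be sketched in plain C++. It maps each export name to its unique (execution affinity, resource affinities) pairs, kept in first-seen order like the `SetVector` in the map declaration above. `ExportDispatchSites`, `record`, and the `_<index>` naming scheme for cloned executables are hypothetical illustrations, not the prototype's implementation. The linear search is only meant for small examples.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// (execution affinity, per-resource affinities) of one dispatch site.
using DispatchVariant = std::pair<std::string, std::vector<std::string>>;

// Hypothetical analogue of `exportDispatchSites`: export name -> unique
// variants, kept in first-seen order like a SetVector.
class ExportDispatchSites {
public:
  // Records a dispatch site and returns its variant index; sites that share
  // an index can share one cloned executable.
  std::size_t record(const std::string &exportName,
                     const DispatchVariant &variant) {
    std::vector<DispatchVariant> &variants = sites[exportName];
    for (std::size_t i = 0; i < variants.size(); ++i)
      if (variants[i] == variant)
        return i;
    variants.push_back(variant);
    return variants.size() - 1;
  }

  std::size_t numVariants(const std::string &exportName) const {
    auto it = sites.find(exportName);
    return it == sites.end() ? std::size_t{0} : it->second.size();
  }

private:
  std::map<std::string, std::vector<DispatchVariant>> sites;
};

// Hypothetical naming scheme for the executable cloned for a variant.
std::string clonedExecutableName(const std::string &executable,
                                 std::size_t variantIndex) {
  return executable + "_" + std::to_string(variantIndex);
}
```

With the two dispatches of `@foo_dispatch_0_set_encoding_LHS_DxD` from the earlier IR, one on `@device_a` and one on `@device_b`, this yields two variants and therefore two cloned executables.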
#### Update Executables
At this stage (still within the SpecializeEncoding pass), we have finished updating the Stream tensor ops and `stream.async.dispatch` ops, and the executables are cloned. The missing part is how to update the executables that have encoded tensors (see the snippet below). The encodings in the executables are not updated yet. The executable does not know what the actual layout is, because the input could be produced by other devices. What we know at this stage is the mapping between `stream.binding`s and the affinities. We want to encode that information in the executables; this is the last step of the encoding specialization pass.
```mlir!
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#encoding_lhs = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_rhs = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>
#encoding_acc = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>
func.func @foo_dispatch_6_matmul_DxDxD_f32(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: index, %arg5: !stream.binding) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = flow.dispatch.workload.ordinal %arg2, 0 : index
%1 = flow.dispatch.workload.ordinal %arg3, 1 : index
%2 = flow.dispatch.workload.ordinal %arg4, 2 : index
%3 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_lhs>>{%1, %0}
%4 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_rhs>>{%0, %2}
%5 = stream.binding.subspan %arg5[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #encoding_acc>>{%1, %2}
%6 = flow.dispatch.tensor.load %3, offsets = [0, 0], sizes = [%1, %0], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_lhs>>{%1, %0} -> tensor<?x?xf32, #encoding_lhs>
%7 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%0, %2], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding_rhs>>{%0, %2} -> tensor<?x?xf32, #encoding_rhs>
%8 = tensor.empty(%1, %2) : tensor<?x?xf32, #encoding_acc>
%9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<?x?xf32, #encoding_acc>) -> tensor<?x?xf32, #encoding_acc>
%10 = linalg.matmul
ins(%6, %7 : tensor<?x?xf32, #encoding_lhs>, tensor<?x?xf32, #encoding_rhs>)
outs(%9 : tensor<?x?xf32, #encoding_acc>) -> tensor<?x?xf32, #encoding_acc>
flow.dispatch.tensor.store %10, %5, offsets = [0, 0], sizes = [%1, %2], strides = [1, 1] : tensor<?x?xf32, #encoding_acc> -> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #encoding_acc>>{%1, %2}
return
}
```
Some facts that I learned from Ben during the prototype:
- `stream.binding`s are only used (today, but we can add interfaces if needed) by binding subspan ops that return the [flow.dispatch.tensor type](https://github.com/iree-org/iree/blob/4aa08f2870b3ecd40808e57888a9d26c1d424c78/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.h#L46-L50).
- In Stream, it's only safe to touch `stream.binding` or ops directly using it, so we don't want Stream touching Flow's dispatch tensor type directly.
- Executable contents besides the stream bindings and stream ops are intended to be opaque and could contain any dialects.
The proposal is to add a TypeInterface to the Encoding dialect and attach it to the Flow DispatchTensorType. We can walk each stream binding argument, visit every op that accesses it, and update the types of its OpResults through the interface (i.e., `result.setType(encodedType.updateEncoding(newEncodingAttr))`). Then we can attach the encoding solvers to the flow.dispatch.tensor types, which means the executable knows what the target encoding is.
With the encoding type interface, we're able to support more general cases. E.g., if a target decides to use something else (e.g., memrefs, my_magical_tensor_type, etc), it can still get the encoding information.
```tablegen=
def IREEEncoding_EncodingTypeInterface :
TypeInterface<"EncodingTypeInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
let description = [{
Interface used to access/update tensor types with encodings.
}];
let methods = [
InterfaceMethod<
[{
Returns the tensor type with the updated encoding.
}],
/*retTy=*/"::mlir::Type",
/*methodName=*/"updateEncoding",
/*args=*/(ins
"::mlir::iree_compiler::IREE::Encoding::EncodingAttr":$encoding),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
}]
>
];
}
```
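To show why the type interface keeps Stream away from Flow types, the sketch below models it in plain C++. `updateBindingUsers` rewrites each binding user's result type only through the abstract `EncodingType` interface, so it never names `DispatchTensorType`. All classes here are hypothetical stand-ins, and encodings are printed strings rather than attributes.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical analogue of EncodingTypeInterface.
struct EncodingType {
  virtual ~EncodingType() = default;
  // Returns a copy of this type carrying `newEncoding`.
  virtual std::unique_ptr<EncodingType>
  updateEncoding(const std::string &newEncoding) const = 0;
  virtual std::string encoding() const = 0;
};

// Stand-in for !flow.dispatch.tensor<...>, implementing the interface.
struct DispatchTensorType : EncodingType {
  std::string access;       // e.g. "readonly" or "writeonly"
  std::string encodingAttr; // printed encoding attribute
  DispatchTensorType(std::string access, std::string encodingAttr)
      : access(std::move(access)), encodingAttr(std::move(encodingAttr)) {}

  std::unique_ptr<EncodingType>
  updateEncoding(const std::string &newEncoding) const override {
    return std::make_unique<DispatchTensorType>(access, newEncoding);
  }
  std::string encoding() const override { return encodingAttr; }
};

// Stand-in for the result of a stream.binding.subspan op.
struct BindingUser {
  std::unique_ptr<EncodingType> resultType;
};

// Rewrites every binding user's result type via the interface only; this
// function never mentions DispatchTensorType.
void updateBindingUsers(std::vector<BindingUser> &users,
                        const std::string &newEncoding) {
  for (BindingUser &user : users)
    user.resultType = user.resultType->updateEncoding(newEncoding);
}
```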
Below is an example IR after updating an executable. Note that only the encodings on flow.dispatch.tensor types have the target list; the rest of the computation ops still use the encodings without any update. The expectation is that a codegen pass (ideally a unified MaterializeEncodingPass) uses analysis/propagation/etc. on the `flow.dispatch.tensor.load/store` ops to get the target encoding information, and materializes the encodings for the rest of the operations.
```mlir=
#encoding_solver2 = #iree_cpu.cpu_encoding_solver<target_configuration = {cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+f16c,+fsgsbase,+crc32,+invpcid,+rdpru,+sahf,+lzcnt,+movbe,+mwaitx,+x87,+pku,+evex512,+prfchw,+rdpid,+rdrnd,+rdseed,+sha,+shstk,+vaes,+wbnoinvd,+xsave,+xsavec,+xsaveopt,+xsaves,+fxsr", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 64 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
#encoding_solver3 = #iree_cpu.vmvx_encoding_solver<target_configuration = {ukernels = "all"}>
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver2]>
#encoding1 = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#encoding2 = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver3]>
#encoding3 = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver2]>
#encoding4 = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#encoding5 = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver3]>
#encoding6 = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver2]>
#encoding7 = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>>
#encoding8 = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 32, 32, 32>, targets = [#encoding_solver3]>
module {
util.global private @device_a = #device_target_local_0_
util.global private @device_b = #device_target_local_1_
stream.executable private @foo_dispatch_2 {
stream.executable.export public @foo_dispatch_2_matmul_DxDxD_f32 workgroups(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0, %arg1, %arg2, %arg3
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @foo_dispatch_2_matmul_DxDxD_f32(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: !stream.binding) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = flow.dispatch.workload.ordinal %arg2, 0 : index
%1 = flow.dispatch.workload.ordinal %arg3, 1 : index
%2 = flow.dispatch.workload.ordinal %arg4, 2 : index
%3 = flow.dispatch.workload.ordinal %arg5, 3 : index
%4 = stream.binding.subspan %arg0[%c0] : !stream.binding
-> !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding>>{%2, %0}
%5 = stream.binding.subspan %arg1[%c0] : !stream.binding
-> !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding3>>{%1, %3}
%6 = stream.binding.subspan %arg6[%c0] : !stream.binding
-> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #encoding6>>{%2, %3}
%7 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [%2, %0], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding>>{%2, %0}
-> tensor<?x?xf32, #encoding1>
%8 = flow.dispatch.tensor.load %5, offsets = [0, 0], sizes = [%1, %3], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding3>>{%1, %3}
-> tensor<?x?xf32, #encoding4>
%9 = tensor.empty(%2, %3) : tensor<?x?xf32, #encoding7>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x?xf32, #encoding7>) -> tensor<?x?xf32, #encoding7>
%11 = linalg.matmul ins(%7, %8 : tensor<?x?xf32, #encoding1>, tensor<?x?xf32, #encoding4>) outs(%10 : tensor<?x?xf32, #encoding7>) -> tensor<?x?xf32, #encoding7>
flow.dispatch.tensor.store %11, %6, offsets = [0, 0], sizes = [%2, %3], strides = [1, 1]
: tensor<?x?xf32, #encoding7>
-> !flow.dispatch.tensor<writeonly:tensor<?x?xf32, #encoding6>>{%2, %3}
return
}
}
}
}
```
*Note: Take a look at the [prototype commit](https://github.com/hanhanW/iree/commit/afc5349d25856ed01c344ecd8a2569dda9f06c98) if you're interested in what it looks like in the code.*
### Encode Host/Device Tensors
This section looks easy, but there are some challenges. For the host code, we can iterate over the targets in the encodings to query all potential storage sizes, and pick the maximum. Nothing special is needed for device code if the device has a single executable target.
However, it is hard when the device has several executable targets. We don't know what to do at this moment, and perhaps it is not possible to do data-tiling in that case.
The encoding specialization provides consistency between host encodings and device encodings. This allows us to drop the encodings when we don't know how to handle them at this stage. Note that this is very different from the encoding cancellation we have today. Today we rely on Codegen to silently drop the encodings, i.e., it runs a dispatch-scoped pass (the MaterializeEncodingIntoNop pass) to drop them. The specialization opens up the opportunity to cancel encodings at the Stream level: we can rewrite the types to drop the encodings when a device has multiple executable targets.
Thus, I see it as a big improvement for data-tiling + multi-device in IREE.
Encoding cancellation at the Stream level is not implemented in the prototype yet, but it is doable. The [prototype](https://github.com/hanhanW/iree/commit/81b7f74d15ec13db824de41764f823b73d80da61) only iterates over all the potential encoding solvers and chooses the maximum size.
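The prototype's host-side policy of picking the maximum size can be sketched as follows. Each solver is modeled as a callback that returns a size in bytes. `StorageSizeFn` and `hostStorageSize` are hypothetical names. Returning `-1` for an empty list is an assumption of this sketch, not part of the design.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Each solver computes the storage size (bytes) of the same logical tensor.
using StorageSizeFn = std::function<std::int64_t()>;

// Prototype-style host policy: allocate the maximum size any target may need.
// Returns -1 when there are no solvers (caller falls back to the plain size).
std::int64_t hostStorageSize(const std::vector<StorageSizeFn> &solvers) {
  std::int64_t maxSize = -1;
  for (const auto &solver : solvers)
    maxSize = std::max(maxSize, solver());
  return maxSize;
}
```

Taking the maximum is conservative. The buffer fits whichever target's layout ends up being written, at the cost of possibly over-allocating for the smaller layouts.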
### CodeGen Changes
Everything in the IR is the same, except that the `flow.dispatch.tensor.load/store` ops now have encodings with "targets". We should be able to move all the materialization logic to interface methods, and run patterns to lower the ops with encodings for the target. This part is not done in the prototype.
What the [prototype](https://github.com/hanhanW/iree/commit/0bf3be828a6f07f9d94c44f6d3222603714eef5e) does is check whether all the encoding solvers are identical. If not, it drops the encoding (ideally this should happen in Stream::EncodeDevicePass, but it is just a prototype). If so, it creates a fake ExecutableTargetAttr to make it work with the existing code. Note again that this is just a prototype. Ideally we should refactor the logic into interface methods and query everything from them. The main goal of this prototype is to prove that the design can have an e2e path.
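The prototype's identical-solver check can be sketched as below, with solvers compared by their printed form. `uniqueSolver` is a hypothetical helper, and returning `std::nullopt` stands for dropping the encoding. Treating an empty list the same way is an assumption of the sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Returns the shared solver when every binding's encoding solver is the same;
// otherwise std::nullopt, meaning "drop the encoding" (no relayout).
std::optional<std::string>
uniqueSolver(const std::vector<std::string> &bindingSolvers) {
  if (bindingSolvers.empty())
    return std::nullopt;
  const std::string &first = bindingSolvers.front();
  bool allSame = std::all_of(bindingSolvers.begin(), bindingSolvers.end(),
                             [&](const std::string &s) { return s == first; });
  if (!allSame)
    return std::nullopt;
  return first;
}
```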
The proposed changes on the Codegen side are:
1. Create an IREECPU Codegen dialect, like what we have for IREEGPU.
2. Implement the encoding solvers for each backend. Some of them are done for VMVX and CPU in the prototype.
3. Introduce more interface methods and refactor the materialization logic into interface methods.
4. Have a unified MaterializeEncoding pass that runs on all the backends. (Materialization already happens on all the backends, but they use different passes.)
This gives us support for generating **device_b** layouts on **device_a**, and it removes the limitation on GEMM tile size selection, especially for narrow matrix cases.
### Putting It All Together (End-To-End IRs)
Today, it is hard to pass flags for the device configs because MLIR attributes don't work well in shells with all the `#`'s and such. In this section, we write the MLIR program with hard-coded executable targets and use the iree-compile/iree-run-module tools to demonstrate the e2e sample. The sample runs a matmul on an x86 CPU backend and a matmul (with the same input data) on the VMVX backend, then outputs the sum of both results.
[Here](https://gist.github.com/hanhanW/ac1fa5cfd3459d0063b6429feee68e36) are selected IR dumps, including the IR before/after SpecializeEncodingPass and CPUMaterializeDeviceEncodingPass.
```bash!
# Compilation
iree-compile --iree-execution-model=async-external ~/matmul.mlir -o /tmp/z.vmfb --iree-global-opt-enable-early-materialization=false
# Execution
iree-run-module --module=/tmp/z.vmfb --function=foo --input=2x3xf32=1,2,3,4,5,6 --input=3x5xf32=1 --device=local-task --device=local-task
# EXEC @foo
# result[0]: hal.buffer_view
# 2x5xf32=[12 12 12 12 12][30 30 30 30 30]
```
```mlir=
// Zen4 CPU
#executable_target_embedded_elf_x86_64_with_encoding_solver = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64",
{cpu = "znver4", cpu_features = "+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+sse4a,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vbmi,+avx512ifma,+avx512vpopcntdq,+avx512vbmi2,+gfni,+vpclmulqdq,+avx512vnni,+avx512bitalg,+avx512bf16,+adx,+clflushopt,+clwb,+clzero,+cx16,+cx8,+f16c,+fsgsbase,+crc32,+invpcid,+rdpru,+sahf,+lzcnt,+movbe,+mwaitx,+x87,+pku,+evex512,+prfchw,+rdpid,+rdrnd,+rdseed,+sha,+shstk,+vaes,+wbnoinvd,+xsave,+xsavec,+xsaveopt,+xsaves,+fxsr",
data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128",
native_vector_size = 64 : i64,
target_triple = "x86_64-unknown-unknown-eabi-elf",
encoding_solver = #iree_cpu.cpu_encoding_solver<>
}>
// VMVX with ukernels enabled.
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_solver = #iree_cpu.vmvx_encoding_solver<>, ukernels = "all"}>
util.global private @device_a = #hal.device.target<"local", {ordinal = 0 : index}, [
#executable_target_embedded_elf_x86_64_with_encoding_solver
]> : !hal.device
util.global private @device_b = #hal.device.target<"local", {ordinal = 1 : index}, [
#executable_target_vmvx_bytecode_fb
]> : !hal.device
func.func @foo(
%lhs: tensor<?x?xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>},
%rhs: tensor<?x?xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>}) -> (tensor<?x?xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>}) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%M = tensor.dim %lhs, %c0 : tensor<?x?xf32>
%K = tensor.dim %lhs, %c1 : tensor<?x?xf32>
%N = tensor.dim %rhs, %c1 : tensor<?x?xf32>
%cst = arith.constant 0.0 : f32
%init = tensor.empty(%M, %N) : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%op = linalg.matmul
ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
// Execute matmul on device_a and transfer the result to device_b
%transient_op = flow.tensor.transfer %op : tensor<?x?xf32>{%M, %N} to #hal.device.affinity<@device_b>
// Transfer input data to device_b
%lhsb = flow.tensor.transfer %lhs : tensor<?x?xf32>{%M, %K} to #hal.device.affinity<@device_b>
%rhsb = flow.tensor.transfer %rhs : tensor<?x?xf32>{%K, %N} to #hal.device.affinity<@device_b>
%initb = tensor.empty(%M, %N) : tensor<?x?xf32>
%fillb = linalg.fill ins(%cst : f32) outs(%initb : tensor<?x?xf32>) -> tensor<?x?xf32>
// Execute matmul on device_b and accumulate the result and the result from device_a.
%opb = linalg.matmul
ins(%lhsb, %rhsb : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fillb : tensor<?x?xf32>) -> tensor<?x?xf32>
%add = arith.addf %transient_op, %opb : tensor<?x?xf32>
// Transfer the result from device_b -> device_a.
%result_a = flow.tensor.transfer %add : tensor<?x?xf32>{%M, %N} to #hal.device.affinity<@device_a>
// Return the result on device_a.
func.return %result_a : tensor<?x?xf32>
}
```
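To sanity-check the expected output of the `iree-run-module` invocation, a naive C++ reference computation is sketched below. It computes the same matmul twice (once per device in the real program) and adds the results. The helper names are hypothetical. With `lhs = [[1 2 3][4 5 6]]` and an all-ones 3x5 `rhs`, each row of one matmul sums to 6 and 15, so the sum is 12 and 30. This matches `2x5xf32=[12 12 12 12 12][30 30 30 30 30]`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Naive reference matmul: (M x K) * (K x N) -> (M x N).
Matrix matmul(const Matrix &lhs, const Matrix &rhs) {
  std::size_t m = lhs.size(), k = rhs.size(), n = rhs[0].size();
  Matrix out(m, std::vector<float>(n, 0.0f));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t p = 0; p < k; ++p)
        out[i][j] += lhs[i][p] * rhs[p][j];
  return out;
}

// Element-wise sum, mirroring the arith.addf of the two device results.
Matrix add(const Matrix &a, const Matrix &b) {
  Matrix out = a;
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[i].size(); ++j)
      out[i][j] += b[i][j];
  return out;
}
```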