# [Exp] Tracing EmbeddingBag module source code in pytorch

###### tags: `research-DLRM`

[TOC]

## Testing program
- This is a simple testing program that demonstrates the EmbeddingBag operation.
- The weight matrix (a.k.a. the embedding lookup table) is 4x2 and is initialized with the fixed values 0 to 7, not with random weights.
- The input sequence is [0, 2, 3].
- The offsets [0, 1] split the input into 2 bags ("sentences"):
    - First bag: input[0:1] = [0]
    - Second bag: input[1:] = [2, 3]
- With `mode='sum'`, each output row is the sum of the embedding rows selected by its bag: row 0 is weight[0] = [0, 1], and row 1 is weight[2] + weight[3] = [10, 12].

```python=
import torch
import torch.nn as nn

weight = torch.DoubleTensor([[0, 1], [2, 3], [4, 5], [6, 7]])
embedding_sum = nn.EmbeddingBag.from_pretrained(weight, mode='sum')
input = torch.LongTensor([0, 2, 3])
offsets = torch.LongTensor([0, 1])
e = embedding_sum(input, offsets)
print(e)  # rows: [0, 1] and [10, 12] (dtype float64)
```

## EmbeddingBag function call graph
To understand the EmbeddingBag operation in PyTorch, we trace its function call stack through the PyTorch source code. Briefly, the calls are split into a Python part and a C++ part.
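As a reference point for the code traced below, the semantics of the testing program can be reproduced without PyTorch. The following is a minimal plain-Python sketch, assuming `mode='sum'`, no `per_sample_weights`, no `padding_idx`, and `include_last_offset=False` (so the last bag runs to the end of the indices). The function name `embedding_bag_sum` is hypothetical and is not a PyTorch API.

```python
from typing import List, Sequence


def embedding_bag_sum(weight: Sequence[Sequence[float]],
                      indices: Sequence[int],
                      offsets: Sequence[int]) -> List[List[float]]:
    """Hypothetical reference for EmbeddingBag(mode='sum') without padding/weights."""
    dim = len(weight[0])
    output: List[List[float]] = []
    for b, start in enumerate(offsets):
        # Bag b covers indices[offsets[b]:offsets[b + 1]];
        # the last bag extends to the end of `indices`.
        end = offsets[b + 1] if b + 1 < len(offsets) else len(indices)
        row = [0.0] * dim  # an empty bag yields a zero row
        for idx in indices[start:end]:
            if not 0 <= idx < len(weight):
                raise IndexError(f"index {idx} out of range")
            for j in range(dim):
                row[j] += weight[idx][j]
        output.append(row)
    return output


weight = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(embedding_bag_sum(weight, [0, 2, 3], [0, 1]))  # [[0.0, 1.0], [10.0, 12.0]]
```

The result matches the two rows described for the testing program: bag 0 selects only weight[0], and bag 1 adds weight[2] and weight[3].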
```graphviz
digraph G {
  graph [fontsize=10 fontname="Verdana" compound=true];
  node [shape=record fontsize=10 fontname="Verdana"];

  subgraph cluster_py {
    node [style=filled];
    {rank=same;}
    l1 [label="nn.EmbeddingBag"];
    l2 [label="F.embedding_bag"];
    l3 [label="handle_torch_function"];
    l1 -> l2 -> l3;
    label = "Python function call";
    color=blue;
  }

  subgraph cluster_cpp {
    node [style=filled];

    // EmbeddingBag.cpp
    emb_bag_l1 [label="_embedding_bag_forward_only_cpu \n (EmbeddingBag.cpp)"];
    emb_bag_l2 [label="_embedding_bag_cpu_impl \n (EmbeddingBag.cpp)"];
    emb_bag_l3 [label="_embedding_bag_cpu_impl_out \n (EmbeddingBag.cpp)"];

    // index_select_add
    subgraph cluster_float {
      node [style=empty];
      embbag_float [label="index_select_add (float)"];
      embbag_by_caffe [label="caffe2::EmbeddingLookupIdx()"];
      embbag_by_aten [label="at::native::cpublas::axpy\<float\>()"];
      decision [label="is_fast_path_index_select"]
      embbag_float -> decision;
      decision -> embbag_by_caffe;
      decision -> embbag_by_aten;
      embbag_by_caffe -> caffe_native;
      embbag_by_caffe -> caffe_avx2;

      subgraph cluster_caffe {
        node [style=filled];
        caffe_native [label="EmbeddingLookupGenericSlowIdx"];
        caffe_avx2 [label="EmbeddingLookupIdx_int32_t_float_float__avx2_fma"];
        label = "EmbeddingLookupIdx from caffe";
        color=orange;
      }

      label = "EmbeddingBag with float (same flow for half)";
      color=red;
    }

    subgraph cluster_neither {
      node [style=empty];
      embbag_neither [label="index_select_add (!float & !half)"];
      embbag_neither_by_aten [label="at::native::cpublas::axpy\<data_t\>()"];
      embbag_neither -> embbag_neither_by_aten;
      label = "EmbeddingBag with neither";
      color=red;
    }

    // Define connections
    emb_bag_l1 -> emb_bag_l2 -> emb_bag_l3;
    emb_bag_l3 -> embbag_float;
    emb_bag_l3 -> embbag_neither;

    label = "C++ function call";
    color=blue;
  }

  // Edges between nodes render fine
  // Edges that directly connect one cluster to another
  l3 -> emb_bag_l1;
}
```

## Python function call
The underlying implementation of PyTorch is written in C++, while Python merely serves as an interface.

#### [nn.EmbeddingBag](https://github.com/WeiCheng14159/pytorch/blob/master/torch/nn/modules/sparse.py#L221)
The **[nn.EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html)** class is a subclass of **[torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)**, the base class for all neural network modules in PyTorch. Its `forward` method calls F.embedding_bag.

#### [F.embedding_bag](https://github.com/WeiCheng14159/pytorch/blob/master/torch/nn/functional.py#L2200)
This is the functional form that does the actual work on the Python side: it validates the arguments and then calls the C++ operator `torch.embedding_bag`.

#### [handle_torch_function](https://github.com/WeiCheng14159/pytorch/blob/master/torch/overrides.py#L1460)
This function implements the `__torch_function__` override protocol. In the call graph above, F.embedding_bag only goes through it when one of its arguments is a Tensor-like object that overrides `__torch_function__`; for plain tensors, F.embedding_bag calls the C++ operator directly.

## C++ function call
### [EmbeddingBag.cpp](https://github.com/WeiCheng14159/pytorch/blob/master/aten/src/ATen/native/EmbeddingBag.cpp)
This C++ file implements the EmbeddingBag operator. PyTorch runs on heterogeneous hardware, so EmbeddingBag has implementations for several backends; to keep things simple, we only focus on the CPU implementation for now. The following sections go through the C++ functions that the CPU path calls.

#### [_embedding_bag_cpu_impl_out](https://github.com/WeiCheng14159/pytorch/blob/c18a18cae94b4b8bb8541319c2f39b7e67b7621d/aten/src/ATen/native/EmbeddingBag.cpp#L1026)
This function is where the CPU computation starts; in the call graph above it is reached via _embedding_bag_forward_only_cpu and _embedding_bag_cpu_impl. It dispatches on the data type of the weight (float, double, half, or bfloat16) and of the indices, and then calls a different subroutine depending on the parameters passed to EmbeddingBag. For instance, **MODE_SUM** and **MODE_MEAN** go through the code below (index_select_add, or index_select_scale_add when per_sample_weights is given), while **MODE_MAX** is handled by a different subroutine.

:::spoiler Code
```cpp=
...
  if (mode == MODE_MEAN || mode == MODE_SUM) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_no_grad_cpu_out",
      [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_no_grad_cpu_out",
        [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() {
        if (per_sample_weights.has_value() && per_sample_weights.value().defined()) {
          TORCH_INTERNAL_ASSERT(mode == MODE_SUM);
          index_select_scale_add<scalar_t, index_t>(
            indices, offset2bag, per_sample_weights.value(), weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
        } else {
          index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache);
        }
      });
    });
  ...
```
:::

#### [index_select_add](https://github.com/WeiCheng14159/pytorch/blob/master/aten/src/ATen/native/EmbeddingBag.cpp#L1026)
For **MODE_SUM** and **MODE_MEAN** without per-sample weights, the work is done by index_select_add. This template has several overloads, selected with `std::enable_if` on the data type: the one below handles float (half follows the same flow), and a generic one handles every other type. Depending on certain criteria, the float overload offloads the computation either to the **Caffe2 library** or to the **ATen library**.

:::spoiler Code
```cpp=
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_add(const Tensor &select_indices,
                 const Tensor &add_indices,
                 const Tensor &src,
                 Tensor &output,
                 const Tensor& offsets,
                 bool include_last_offset,
                 Tensor &bag_size,
                 index_t padding_idx,
                 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {...}
```
:::

The float overload of index_select_add therefore contains two implementations: the **caffe** impl. and the **axpy** impl.
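In the call from _embedding_bag_cpu_impl_out, index_select_add receives `indices` as `select_indices` and `offset2bag` as `add_indices`, so `add_indices[i]` is the bag (output row) that position `i` of the input contributes to. The following plain-Python sketch illustrates that contract for the testing program, assuming contiguous rows, no `padding_idx`, and `include_last_offset=False`. The helpers `make_offset2bag` and `index_select_add_sketch` are hypothetical, and how PyTorch actually builds offset2bag is not reproduced here. The inner loop is the per-index update the axpy impl. performs (with a = 1); the caffe impl. computes the same sums from the offsets directly.

```python
from typing import List, Sequence


def make_offset2bag(offsets: Sequence[int], num_indices: int) -> List[int]:
    """Hypothetical helper: bag id for every position of `indices`.

    Assumes include_last_offset=False, so the last bag runs to num_indices.
    """
    bag_of: List[int] = []
    for b, start in enumerate(offsets):
        end = offsets[b + 1] if b + 1 < len(offsets) else num_indices
        bag_of.extend([b] * (end - start))
    return bag_of


def index_select_add_sketch(select_indices: Sequence[int],
                            add_indices: Sequence[int],
                            src: Sequence[Sequence[float]],
                            num_bags: int) -> List[List[float]]:
    """Hypothetical sketch: output[add_indices[i]] += 1 * src[select_indices[i]]."""
    dim = len(src[0])
    output = [[0.0] * dim for _ in range(num_bags)]
    for idx, bag in zip(select_indices, add_indices):
        if not 0 <= idx < len(src):  # same range check the C++ loop performs
            raise IndexError(f"idx {idx} out of range")
        y = output[bag]
        for j in range(dim):  # one AXPY per index: y <- y + a * x with a = 1
            y[j] += 1 * src[idx][j]
    return output


indices = [0, 2, 3]
offsets = [0, 1]
offset2bag = make_offset2bag(offsets, len(indices))
print(offset2bag)  # [0, 1, 1]
weight = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(index_select_add_sketch(indices, offset2bag, weight, len(offsets)))
# [[0.0, 1.0], [10.0, 12.0]]
```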
The choice between the two is made by the **is_fast_path_index_select(src, output, padding_idx)** function.

- Caffe2 library (caffe impl.)
    - The Caffe2 **EmbeddingLookupIdx** function does the work, and the call is wrapped in **at::parallel_for** so that ranges of bags (start_idx to end_idx) are processed in parallel.

:::spoiler Code
```cpp=
...
caffe2::EmbeddingLookupIdx(
    /*block_size=*/ddim,
    /*output_size=*/end_idx - start_idx,
    /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
    /*data_size=*/src.size(0),
    /*input=*/src_data,
    /*indices=*/select_indices_data + offsets_data[start_idx],
    /*offsets=*/offsets_data + start_idx,
    /*weights=*/nullptr,
    /*scale_bias=*/nullptr,
    /*normalize_by_lengths=*/false,
    /*out=*/output_data + start_idx * ddim);
...
```
:::

- ATen library (axpy impl.)
    - The **at::native::cpublas::axpy\<float\>(...)** function is called once per index: each index adds one embedding row to the output row of its bag, which is exactly the [AXPY](https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) BLAS operation $\vec{y} \leftarrow a\vec{x} + \vec{y}$ with $a = 1$.

:::spoiler Code
```cpp=
...
  AT_ASSERT(select_indices.numel() == add_indices.numel());
  auto* src_data = src.data_ptr<float>();
  auto* add_indices_data = add_indices.data_ptr<index_t>();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  index_t* bag_size_data = nullptr;
  if (bag_size.defined()) {
    bag_size_data = bag_size.data_ptr<index_t>();
  }
  auto vocab_size = src.size(0);
  auto src_stride0 = src.strides()[0];
  auto src_stride1 = src.strides()[1];
  auto output_stride0 = output.strides()[0];
  auto output_stride1 = output.strides()[1];
  auto numel = add_indices.numel();
  for (const auto i : c10::irange(numel)) {
    // We can skip indices equal to padding_idx so they are not included in
    // the reduction
    auto idx = select_indices_data[i];
    TORCH_CHECK(
        idx >= 0 && idx < vocab_size,
        "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
        idx);
    if (idx != padding_idx) {
      at::native::cpublas::axpy<float>(
          ddim,
          1,
          src_data + src_stride0 * idx,
          src_stride1,
          output_data + output_stride0 * add_indices_data[i],
          output_stride1);
    } else if (bag_size.defined()) {
      // Decrement bag_size to reflect that the index is padded
      // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
      bag_size_data[add_indices_data[i]]--;
    }
  }
  ...
};
```
:::

#### [is_fast_path_index_select](https://github.com/WeiCheng14159/pytorch/blob/c18a18cae94b4b8bb8541319c2f39b7e67b7621d/aten/src/ATen/native/EmbeddingBag.cpp#L62)
The **is_fast_path_index_select** function decides whether the EmbeddingBag operator uses the caffe impl. or the axpy impl. The fast (caffe) path requires float or half weights, unit stride along dimension 1 for both the weight and the output, and no padding index (padding_idx < 0):
```cpp
// Determines if we can use a fast implementation for index_select_add, which
// is only applicable if special conditions are met
template<typename index_t>
bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) {
  return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) &&
         src.strides()[1] == 1 && output.strides()[1] == 1 &&
         padding_idx < static_cast<index_t>(0);
}
```

The two implementations exist to increase throughput, as explained in this [commit message on GitHub](https://github.com/pytorch/pytorch/commit/d17c22d024467b7185e33c4652b44739f67965be). According to the author, throughput increased from ~8 GB/s to ~14 GB/s. However, the fast path only applies to float (and half) weights; other data types (e.g., double, as used in the testing program above) always go through the generic axpy implementation.

## Caffe EmbeddingLookupIdx impl.
### [embedding_lookup_idx.cc](https://github.com/WeiCheng14159/pytorch/blob/master/caffe2/perfkernels/embedding_lookup.cc)
This file is the entry point into the Caffe2 library. The EmbeddingLookupIdx functions are defined by the **EMBEDDING_IDX_SPECIALIZATION** macro, which is instantiated once per combination of index, input, and output types (e.g., half or 8-bit quantized inputs). Each instantiation selects at runtime between the generic implementation and an AVX2/FMA-optimized one, depending on what the CPU supports.

- `EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);` instantiates the version with int64_t indices, float input and output, and non-positional weights.

```cpp
#define EMBEDDING_IDX_SPECIALIZATION(                              \
    IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL)  \
```

The generic (non-SIMD) implementation, EmbeddingLookupGenericSlowIdx, is the easiest one to read. For each bag it zeroes the output row, accumulates `w * input[idx] + b` over the bag's indices (w defaults to 1 and b to 0), and optionally divides the row by the bag length.

- Naive CPU implementation

:::spoiler Code
```cpp=
template <...>
static bool EmbeddingLookupGenericSlowIdx(...)
{
  int64_t current = 0;
  for (const auto m : c10::irange(output_size)) {
    memset(out, 0, sizeof(OutType) * block_size);
    if (current != offsets[m] - offsets[0]) {
      return false;
    }
    int64_t start_offset = offsets[m];
    int64_t end_offset = offsets[m + 1];
    int64_t length = end_offset - start_offset;
    for (const auto i : c10::irange(start_offset, end_offset)) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
#ifdef __GNUC__
      if (current + 1 < index_size) {
        __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
      }
#endif // __GNUC__

      float w = 1.f, b = 0.f;
      if (weights) {
        w = weights[IS_WEIGHT_POSITIONAL ? i - start_offset : current];
      }
      if (scale_bias) {
        b = w * scale_bias[2 * indices[current] + 1];
        w = w * scale_bias[2 * indices[current]];
      }

      for (const auto j : c10::irange(block_size)) {
        out[j] += w * input[block_size * indices[current] + j] + b;
      }

      ++current;
    }
    if (normalize_by_lengths && length) {
      float scale = 1.f / length;
      for (const auto j : c10::irange(block_size)) {
        out[j] *= scale;
      }
    }
    out += block_size;
  }
  return current == index_size;
}
```
:::

- Fast CPU implementation
    - Faster implementations use AVX2 SIMD instructions. However, this code is written with low-level AVX2 intrinsics and is very hard to read, so it is only linked here for reference:
        - [embedding_lookup_idx_avx2.cc](https://github.com/WeiCheng14159/pytorch/blob/master/caffe2/perfkernels/embedding_lookup_avx2.cc)
    - The AVX2 code above is generated automatically by the following Python script:
        - [hp_emblookup_codegen.py](https://github.com/WeiCheng14159/pytorch/blob/master/caffe2/perfkernels/hp_emblookup_codegen.py)

## ATen AXPY impl.
In the axpy impl., the reduction is a plain `for` loop over the indices that issues one AXPY call per index:

:::spoiler Code
```cpp=
...
for (const auto i : c10::irange(numel)) {
  // We can skip indices equal to padding_idx so they are not included in
  // the reduction
  auto idx = select_indices_data[i];
  TORCH_CHECK(
      idx >= 0 && idx < vocab_size,
      "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
      idx);
  if (idx != padding_idx) {
    at::native::cpublas::axpy<float>(
        ddim,
        1,
        src_data + src_stride0 * idx,
        src_stride1,
        output_data + output_stride0 * add_indices_data[i],
        output_stride1);
  } else if (bag_size.defined()) {
    // Decrement bag_size to reflect that the index is padded
    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    bag_size_data[add_indices_data[i]]--;
  }
  ...
};
```
:::

### [CPUBlas.h](https://github.com/WeiCheng14159/pytorch/blob/c18a18cae94b4b8bb8541319c2f39b7e67b7621d/aten/src/ATen/native/CPUBlas.h#L128)
This is the entry point for the AXPY function in the ATen library. It is expected to compute the following:

$$\vec{y} \leftarrow \vec{y} + a\,\vec{x}, \quad \text{where } a \text{ is a scalar}$$

or, element by element with strides:

$$y_{i \cdot \text{incy}} \leftarrow y_{i \cdot \text{incy}} + a\,x_{i \cdot \text{incx}}, \quad i = 0, \dots, n-1$$

```cpp=
template<typename scalar_t>
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
  if(n == 1)
  {
    incx = 1;
    incy = 1;
  }
  axpy_stub(
      kCPU, c10::CppTypeToScalarType<scalar_t>::value,
      n, a, x, incx, y, incy);
}
```

- n: Number of elements to update.
- a: Scalar by which $\vec{x}$ is multiplied before being added to $\vec{y}$.
- x: Base address of vector $\vec{x}$.
- incx: Stride of $\vec{x}$: consecutive elements are incx elements apart.
- y: Base address of vector $\vec{y}$, which is updated in place.
- incy: Stride of $\vec{y}$: consecutive elements are incy elements apart.

In the axpy impl. of index_select_add, n is the embedding dimension ddim, a is 1, x points at row idx of the weight, and y points at the output row of the bag.

### [BlasKernel.cpp](https://github.com/WeiCheng14159/pytorch/blob/c18a18cae94b4b8bb8541319c2f39b7e67b7621d/aten/src/ATen/native/cpu/BlasKernel.cpp#L193)
The function **cpublas_axpy_impl(...)** is the CPU kernel behind axpy_stub: it dispatches on the scalar type and computes AXPY with a plain strided loop (for bool, the update is `y |= a & x`).
```cpp=
void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){
  if (type == at::kBool) {
      auto a = _a.to<bool>();
      auto x = static_cast<const bool *>(_x);
      auto y = static_cast<bool *>(_y);
      int64_t i;
      for(i = 0; i < n; i++)
        y[i*incy] |= a & x[i*incx];
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16, type, "cpublas_axpy_impl",
      [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto a = _a.to<opmath_t>();
        auto x = static_cast<const scalar_t *>(_x);
        auto y = static_cast<scalar_t *>(_y);
        int64_t i;
        for(i = 0; i < n; i++)
          y[i*incy] += a*x[i*incx];
      });
  }
}
```

## Appendix
:::info
### gdb skills
- Set breakpoints in all functions in a file
```
(gdb) rbreak file.cpp:.
```
:::