[C++]SIMDをクラステンプレート化したかった話。

動機

諸事情あってSIMDとやらの勉強をしている。ちょっとループ処理の高速化を目論んで色々調べていたのだが、私の使い方だと自動ベクトル化が働かないことが分かったので、自前で書かねばといろいろ勉強中なのである1

ところで、SIMDについて色々と調べているものの、どうにも使いにくくて仕方がない。テンプレートと相性が悪いのだ。私の用途では大抵の型情報がテンプレート引数として与えられているが、SIMDレジスタや命令は基本的にテンプレートと無縁なため、極めて食い合わせが悪い。そこで、レジスタや拡張命令を内蔵するクラステンプレートを作り、無理やり中継させてみることにした。 なお、標準ライブラリにはstd::valarrayという数値演算特化の機能があり、場合によってはSIMDによる並列計算を行ってくれるようだが、動的配列は私の用途に合わないし、根本的に機能が足りないので採用しない。

本記事はSIMDの勉強を兼ねた車輪の再発明である。私を除く多くのC++erはこういうテンプレートでラップしたオレオレSIMDライブラリを作ったことがあるのではないかと思う。そうでなくとも、この手のライブラリはおそらくGitHub等にごまんとあるので、速やかにSIMDを導入することが目的なら勉強したての私のコードを参照するより既存のよくメンテナンスされたライブラリを探して使うほうが圧倒的に良い結果となるであろうことを予め断っておく。とはいえ、ちょっとネット上で調べた程度ではこのような日本語記事を見つけられなかったので、SIMD拡張命令を基礎から勉強中で参考となる簡素な例を探している人には、ひょっとしたら役に立つかもしれない。

一応MSVCで動作確認はしているが、いかんせん初心者が足りない知識を妄想で補間しながら作ったものなので間違いが含まれる可能性が非常に高いし、正直に言うと正しいことをしているのかどうか自信がない。こんなものを作ったところで高速化を果たせるのかも分からない。とにかく作って試せばなんか分かるだろ、という非効率な精神で書きなぐった。問題がある場合は容赦なく指摘してもらえると大変ありがたい。

設計方針

double型を8個分、というように型と要素数をテンプレート引数で与えることで、それに応じたレジスタを保持してくれるようなSimdVecクラステンプレートを作る。かつ、各種演算子や数学関数、キャストなどをSIMD命令によって実装する。
つまり、以下のようなことをしたい。

std::array<double, 8> arr1{ 0,  1,  2,  3,  4,  5,  6,  7 };
std::array<float, 8>  arr2{ 8,  9, 10, 11, 12, 13, 14, 15 };

//double、floatなどの型と要素数をテンプレートで与える。
//__m512d、__m256などのレジスタ型は有効となっている拡張命令セットに従って自動的に選択される。
SimdVec<double, 8> vec1 = arr1;
SimdVec<float, 8> vec2 = arr2;

//SIMD命令を用いたfloat<->doubleの暗黙的/明示的キャスト、演算を直感的に行う。
SimdVec<double, 8> vec3 = vec1 * vec2;
SimdVec<float, 8> vec4 = vec1.to<float>() + vec2;

std::array<double, 8> arr3;
std::array<float, 8> arr4;
vec3.store(arr3);
vec4.store(arr4);
//arr3 == { 0, 9, 20, 33, 48, 65, 84, 105 }
//arr4 == { 8, 10, 12, 14, 16, 18, 20, 22 }

今回は特にAVX、AVX2、AVX-512を想定する。AVX非対応のCPUは既に絶滅危惧種であろうから切り捨てる。AVX-512はIntelが半ば切り捨てつつあるような噂を小耳に挟んだが、私の使っているCPUはたまたま対応しているようだったので、こちらも組み込む2。また、あくまで試験的なものなので、今回はfloatとdoubleのみを考えることにする。整数型はまた後日、余裕があれば記事にする。

実装

#ifndef MY_SIMD_VEC_H
#define MY_SIMD_VEC_H

#include <array>
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>
#endif

template <class T, size_t RegLen>
struct Register;
template <> struct Register<float, 128> { using Type = __m128; };
template <> struct Register<float, 256> { using Type = __m256; };
#ifdef __AVX512F__
template <> struct Register<float, 512> { using Type = __m512; };
#endif
template <> struct Register<double, 128> { using Type = __m128d; };
template <> struct Register<double, 256> { using Type = __m256d; };
#ifdef __AVX512F__
template <> struct Register<double, 512> { using Type = __m512d; };
#endif
template <class T, size_t ReqLen>
using Register_t = typename Register<T, ReqLen>::Type;

template <class T_, size_t ReqLen_>
struct SelectRegArray;
template <size_t ReqLen_> requires (ReqLen_ <= 128)
struct SelectRegArray<float, ReqLen_> { using Type = Register_t<float, 128>; static constexpr size_t Len = 128; static constexpr size_t Width = 1; };
template <size_t ReqLen_> requires (ReqLen_ == 256)
struct SelectRegArray<float, ReqLen_> { using Type = Register_t<float, 256>; static constexpr size_t Len = 256; static constexpr size_t Width = 1; };
template <size_t ReqLen_> requires (ReqLen_ >= 512)
struct SelectRegArray<float, ReqLen_>
{
#ifdef __AVX512F__
    using Type = Register_t<float, 512>; static constexpr size_t Len = 512; static constexpr size_t Width = ReqLen_ / 512;
#elif defined(__AVX2__) || defined(__AVX__)
    using Type = Register_t<float, 256>; static constexpr size_t Len = 256; static constexpr size_t Width = ReqLen_ / 256;
#endif
};

template <size_t ReqLen_> requires (ReqLen_ == 256)
struct SelectRegArray<double, ReqLen_> { using Type = Register_t<double, 256>; static constexpr size_t Len = 256; static constexpr size_t Width = 1; };
template <size_t ReqLen_> requires (ReqLen_ >= 512)
struct SelectRegArray<double, ReqLen_>
{
#ifdef __AVX512F__
    using Type = Register_t<double, 512>; static constexpr size_t Len = 512; static constexpr size_t Width = ReqLen_ / 512;
#elif defined(__AVX2__) || defined(__AVX__)
    using Type = Register_t<double, 256>; static constexpr size_t Len = 256; static constexpr size_t Width = ReqLen_ / 256;
#endif
};


template <std::floating_point T, size_t NElm>
    requires (NElm >= 4 && (NElm & (NElm - 1)) == 0)//Nは4以上の2の累乗
struct SimdVec
{
    SimdVec() = default;
    SimdVec(const std::array<T, NElm>& arr)
    {
        if constexpr (std::same_as<T, double>)
        {
            if constexpr (RegLen == 512)
            {
                for (size_t i = 0; i < Width; i++)
                    data[i] = _mm512_loadu_pd(arr.data() + i * 8);
            }
            else
            {
                for (size_t i = 0; i < Width; i++)
                    data[i] = _mm256_loadu_pd(arr.data() + i * 4);
            }
        }
        else
        {
            if constexpr (RegLen == 512)
            {
                for (size_t i = 0; i < Width; i++)
                    data[i] = _mm512_loadu_ps(arr.data() + i * 16);
            }
            else if constexpr (RegLen == 256)
            {
                for (size_t i = 0; i < Width; i++)
                    data[i] = _mm256_loadu_ps(arr.data() + i * 8);
            }
            else
            {
                data[0] = _mm_loadu_ps(arr.data());
            }
        }
    }

    template <class T2>
        requires std::same_as<T2, T>
    SimdVec<T, NElm> to() const
    {
        //変換先が同じ型なら、そのまま自身を返すだけでよい。
        return *this;
    }

    template <class T2>
        requires std::same_as<T2, float>&& std::same_as<T, double>
    SimdVec<float, NElm> to() const
    {
        //自身がSimdVec<double, NElm>で、要素数の等しいSimd<float, NElm>へ変換したい場合の処理。
        SimdVec<float, NElm> res;
        if constexpr (Width == 1)
        {
            //Width==1の場合は、floatもWidth==1でレジスタのビット数が半分になっている。
            if constexpr (RegLen == 512)
                res.data[0] = _mm512_cvtpd_ps(data[0]);
            else
                res.data[0] = _mm256_cvtpd_ps(data[0]);
        }
        else
        {
            //Width>=2の場合、SimdVec<float, NElm>の中身は場合によって異なる。
            constexpr size_t WidthF = Width / 2;
            for (size_t i = 0; i < WidthF; ++i)
            {
                if constexpr (RegLen == 512)
                {
                    __m256 r1 = _mm512_cvtpd_ps(data[2 * i]);
                    __m256 r2 = _mm512_cvtpd_ps(data[2 * i + 1]);
                    res.data[i] = _mm512_insertf32x8(_mm512_castps256_ps512(r1), r2, 1);
                }
                else
                {
                    __m128 r1 = _mm256_cvtpd_ps(data[2 * i]);
                    __m128 r2 = _mm256_cvtpd_ps(data[2 * i + 1]);
                    res.data[i] = _mm256_insertf32x4(_mm256_castps128_ps256(r1), r2, 1);
                }
            }
        }
        return res;
    }

    template <class T2>
        requires std::same_as<T2, double>&& std::same_as<T, float>
    SimdVec<double, NElm> to() const
    {
        //自身がSimdVec<float, NElm>で、要素数の等しいSimd<double, NElm>へ変換したい場合の処理。
        SimdVec<double, NElm> res;

        if constexpr (RegLen == 512)
        {
            //自身が__m512だとすると、SimdVec<double, NElm>も__m512dを2*Width個保持している。
            for (size_t i = 0; i < Width; i++)
            {
                __m512 r = data[i];
                __m256 r1 = _mm512_extractf32x8_ps(r, 0);
                __m256 r2 = _mm512_extractf32x8_ps(r, 1);
                res.data[2 * i] = _mm512_cvtps_pd(r1);
                res.data[2 * i + 1] = _mm512_cvtps_pd(r2);
            }
        }
        else if constexpr (RegLen == 256)
        {
            //256の場合、可能性は2つ。一つは、AVX-512が有効である場合。
            //この場合、確実にWidth==1である。そうでなければ__m512が選択されているはずで、RegLen!=256でなければおかしい。
            //よって、単に__m256->__m512dに変換するだけでよい。
            if constexpr (SimdVec<double, NElm>::RegLen == 512)
            {
                static_assert(Width == 1);
                res.data[0] = _mm512_cvtps_pd(data[0]);
            }
            //もう一つは、AVX-512が有効でないため__m256を2つ以上保持している、すなわちWidth>=2の場合。
            //この場合、SimdVec<double>側も__m256dで保持しているため、
            //長さを合わせつつ変換する必要がある。
            else
            {
                for (size_t i = 0; i < Width; i++)
                {
                    __m128 r1 = _mm256_extractf128_ps(data[i], 0);
                    __m128 r2 = _mm256_extractf128_ps(data[i], 1);
                    res.data[2 * i] = _mm256_cvtps_pd(r1);
                    res.data[2 * i + 1] = _mm256_cvtps_pd(r2);
                }
            }
        }
        else
        {
            //RegLen==128の場合、必ずWidth==1である。
            //ということは、SimdVec<double>側は確実にRegLen==256である。
            static_assert(RegLen == 128 && Width == 1);
            res.data[0] = _mm256_cvtps_pd(data[0]);
        }
        return res;
    }

    void store(std::array<T, NElm>& arr) const
    {
        if constexpr (std::same_as<T, double>)
        {
            if constexpr (RegLen == 512)
            {
                for (size_t i = 0; i < Width; i++)
                    _mm512_storeu_pd(arr.data() + i * 8, data[i]);
            }
            else
            {
                for (size_t i = 0; i < Width; i++)
                    _mm256_storeu_pd(arr.data() + i * 4, data[i]);
            }
        }
        else
        {
            if constexpr (RegLen == 512)
            {
                for (size_t i = 0; i < Width; i++)
                    _mm512_storeu_ps(arr.data() + i * 16, data[i]);
            }
            else if constexpr (RegLen == 256)
            {
                for (size_t i = 0; i < Width; i++)
                    _mm256_storeu_ps(arr.data() + i * 8, data[i]);
            }
            else
            {
                _mm_storeu_ps(arr.data(), data[0]);
            }
        }
    }

    using Type = T;
    static constexpr size_t RequiredLen = sizeof(T) * 8 * NElm;
    static constexpr size_t RegLen = SelectRegArray<T, RequiredLen>::Len;
    using RegType = typename SelectRegArray<T, RequiredLen>::Type;

    static constexpr size_t Width = SelectRegArray<T, RequiredLen>::Width;
    std::array<RegType, Width> data;
};

template <size_t NElm>
SimdVec<double, NElm> operator+(SimdVec<double, NElm> x, SimdVec<double, NElm> y)
{
    SimdVec<double, NElm> res;
    for (size_t i = 0; i < res.Width; i++)
    {
        if constexpr (SimdVec<double, NElm>::RegLen == 512)
            res.data[i] = _mm512_add_pd(x.data[i], y.data[i]);
        else
            res.data[i] = _mm256_add_pd(x.data[i], y.data[i]);
    }
    return res;
}
template <size_t NElm>
SimdVec<float, NElm> operator+(SimdVec<float, NElm> x, SimdVec<float, NElm> y)
{
    SimdVec<float, NElm> res;
    for (size_t i = 0; i < res.Width; i++)
    {
        if constexpr (SimdVec<float, NElm>::RegLen == 512)
            res.data[i] = _mm512_add_ps(x.data[i], y.data[i]);
        else if constexpr (SimdVec<float, NElm>::RegLen == 256)
            res.data[i] = _mm256_add_ps(x.data[i], y.data[i]);
        else
            res.data[i] = _mm_add_ps(x.data[i], y.data[i]);
    }
    return res;
}
template <class T1, class T2, size_t NElm>
    requires (!std::same_as<T1, T2>)
auto operator+(SimdVec<T1, NElm> x, SimdVec<T2, NElm> y)
{
    using T = std::decay_t<decltype(std::declval<T1>() + std::declval<T2>())>;
    return x.to<T>() + y.to<T>();
}

#endif

このクラスの肝はキャストである。データの配置を組み替えるような複雑な使い方はしないので、キャストさえ正しく実装できればあとは対応する命令を呼び出す関数、演算子オーバーロードをひたすら書くだけのやっつけ仕事だ。
非常に大雑把に挙動を説明すると、SelectRegArrayに対してdouble/floatの型と必要なビット数を与えることで、レジスタの型__mAAAとその必要個数Widthを取得し、SimdVecにstd::array<__mAAA, Width> dataというメンバ変数として持たせている。キャストするときは愚直に、レジスタのビット数とWidthの取りうる可能性を全て列挙し分岐させている。

あらゆる分岐を全て書き下すというひどい実装ながらも、とりあえず動くものにはなった。ここに追加で整数型を組み込むことまで考えるとちょっとうんざりするのでもう少し整理しなければならない。また本来はいろいろな演算子や関数に対応させたかったが、ここではコード削減のために加算演算子のみを用意した。SIMDの理解及び動作試験が目的なので今回はまあ良しとする。

以下はテストコード。

#include "simdvec.h"
#include <iostream>

template <class T1, class T2, size_t NElm>
void Test()
{
    std::array<T1, NElm> a;
    std::array<T2, NElm> b;
    for (size_t i = 0; i < NElm; ++i)
    {
        a[i] = (T1)i;
        b[i] = (T2)i;
    }
    SimdVec<T1, NElm> aSimd(a);
    SimdVec<T2, NElm> bSimd(b);
    auto resSimd = aSimd + bSimd;//T1、T2が同じならその型、違うなら大きい方の型(要はdouble)に変換した上で加算する。
    using TRes = decltype(resSimd)::Type;//resSimdの型を取得。floatかdoubleか。
    std::array<TRes, NElm> res;
    resSimd.store(res);
    for (size_t i = 0; i < NElm; ++i)
    {
        std::cout << res[i] << " ";
    }
    std::cout << std::endl;

    auto resSimdf = aSimd.to<float>() + bSimd.to<float>();//floatに変換して加算
    std::array<float, NElm> resf;
    resSimdf.store(resf);
    for (size_t i = 0; i < NElm; ++i)
    {
        std::cout << resf[i] << " ";
    }
    std::cout << std::endl;
}

int main()
{
    Test<double, double, 4>();
    Test<float, float, 4>();
    Test<double, float, 4>();
    Test<float, double, 4>();

    Test<double, double, 8>();
    Test<float, float, 8>();
    Test<double, float, 8>();
    Test<float, double, 8>();

    Test<double, double, 16>();
    Test<float, float, 16>();
    Test<double, float, 16>();
    Test<float, double, 16>();

    return 0;
}

MSVCのコンパイルオプションでAVX2、AVX-512それぞれを有効にした状態でビルド、実行し、どちらも想定通りの結果を得ている。


  1. C++歴ウン年のくせに今の今まで勉強していなかったのかって?情けない話だが、余計なオーバーヘッドのないプログラムを書けば満足できる程度の開発しかしてこなかったので、まるで使い所がなく、勉強したくてもできなかったのだ。
  2. 尤も、AVX-512へ対応しているかどうかにはかなり細かくてややこしい話があるらしく、そのあたりを私はまだきちんと理解できていない。