/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

// Expands to `#pragma unroll` via the _Pragma operator, requesting that the loop
// immediately following it be unrolled.
#define ESIMD_UNROLL _Pragma("unroll")

// masked to avoid out-of-bounds read
// Loads buf[0], buf[1], ... into the lanes of a `size`-wide simd vector using a
// gather with per-lane byte offsets 0, sizeof(ty), 2 * sizeof(ty), ...
// When bufsize < size, lanes bufsize .. size-1 are masked off so memory past the
// end of buf is never touched; treat those lanes of the result as undefined.
// When bufsize >= size, all `size` lanes are read.
template <typename ty, uint32_t size>
ESIMD_INLINE esimd::simd<ty, size> gather_read(ty *buf, uint32_t bufsize) {
    // The lane mask below is packed into a 32-bit word.
    static_assert(size <= 32, "gather_read supports at most 32 lanes");

    esimd::simd<ty, size> v;

    // simd(base, step) builds the arithmetic sequence 0, 1, ..., size - 1.
    const esimd::simd<uint64_t, size> offsets(0, 1);
    const esimd::simd<uint64_t, size> offsetBytes = offsets * sizeof(ty);

    if (bufsize >= size) {
        v = esimd::gather<ty, size>(buf, offsetBytes);
    }
    else {
        // bufsize < size <= 32 here, so the shift is always in range. Unsigned
        // arithmetic is required: with a signed int, (1 << 31) - 1 overflows.
        const uint32_t packed = (1u << bufsize) - 1u;
        const esimd::simd_mask<size> mask = esimd::unpack_mask<size>(packed);
        v = esimd::gather<ty, size>(buf, offsetBytes, mask);
    }
    return v;
}

// masked to avoid out-of-bounds write
// Stores the lanes of v to buf[0], buf[1], ... using a scatter with per-lane
// byte offsets 0, sizeof(ty), 2 * sizeof(ty), ...
// When bufsize < size, lanes bufsize .. size-1 are masked off and never stored.
// When bufsize >= size, all `size` lanes are written.
template <typename ty, uint32_t size>
ESIMD_INLINE void scatter_write(ty *buf, uint32_t bufsize, esimd::simd<ty, size> v) {
    // The lane mask below is packed into a 32-bit word.
    static_assert(size <= 32, "scatter_write supports at most 32 lanes");

    // simd(base, step) builds the arithmetic sequence 0, 1, ..., size - 1.
    const esimd::simd<uint64_t, size> offsets(0, 1);
    const esimd::simd<uint64_t, size> offsetBytes = offsets * sizeof(ty);

    if (bufsize >= size) {
        esimd::scatter<ty, size>(buf, offsetBytes, v);
    }
    else {
        // bufsize < size <= 32 here, so the shift is always in range. Unsigned
        // arithmetic is required: with a signed int, (1 << 31) - 1 overflows.
        const uint32_t packed = (1u << bufsize) - 1u;
        const esimd::simd_mask<size> mask = esimd::unpack_mask<size>(packed);
        esimd::scatter<ty, size>(buf, offsetBytes, v, mask);
    }
}

// Vector width (number of lanes) used by the bitonic-sort helpers below. They
// walk it in 32-lane chunks using 32-bit flip masks.
#define BASE_SZ 32

// Stride-8 compare-and-swap pass of the bitonic network over BASE_SZ keys,
// processed in 32-lane chunks. The companion (payload) vector is permuted in
// lockstep.
//
// For every lane x of a chunk, B first receives the partner key A[x ^ 8]: the
// two 8-lane halves of each 16-lane group are swapped. Then
//   mask = (A < B) ^ flip
// marks the lanes where the original A value is restored by the merge. As a
// result, lanes with flip == 0 end up holding the smaller key of their pair,
// and lanes with flip == 1 the larger one. companionB follows the same choices.
//
//   A, companionA - input keys / payload (only read; A is taken by reference)
//   B, companionB - output keys / payload (every lane is written)
//   flip          - per-lane direction bits for one 32-lane chunk
//
// NOTE(review): if the two keys of a pair are equal, the flip == 0 lane takes
// its partner's payload while the flip == 1 lane keeps its own. One payload
// value is then duplicated and the other lost, so this is only safe when keys
// are distinct. Confirm for callers.
ESIMD_INLINE void
bitonic_exchange8(esimd::simd<int32_t, BASE_SZ> &A,
                  esimd::simd<int32_t, BASE_SZ> &B,
                  esimd::simd<double, BASE_SZ>& companionA, /* in */
                  esimd::simd<double, BASE_SZ>& companionB, /* out */
                  esimd::simd_mask<32> flip) {
  ESIMD_UNROLL
  for (int i = 0; i < BASE_SZ; i += 32) {
    // B[x] = A[x ^ 8]: swap the 8-lane halves of both 16-lane groups.
    B.select<8, 1>(i) = A.select<8, 1>(i + 8);
    B.select<8, 1>(i + 8) = A.select<8, 1>(i);
    B.select<8, 1>(i + 16) = A.select<8, 1>(i + 24);
    B.select<8, 1>(i + 24) = A.select<8, 1>(i + 16);

    // Same permutation for the payload.
    companionB.select<8, 1>(i) = companionA.select<8, 1>(i + 8);
    companionB.select<8, 1>(i + 8) = companionA.select<8, 1>(i);
    companionB.select<8, 1>(i + 16) = companionA.select<8, 1>(i + 24);
    companionB.select<8, 1>(i + 24) = companionA.select<8, 1>(i + 16);

    // Restore A's value in the lanes where it already satisfies the order
    // requested by flip.
    auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip;
    B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);

    companionB.select<32, 1>(i).merge(companionA.select<32, 1>(i), mask);
  }
}

// Stride-4 compare-and-swap pass of the bitonic network over BASE_SZ keys,
// processed in 32-lane chunks. The companion (payload) vector is permuted in
// lockstep.
//
// Each chunk is viewed as a 4x8 matrix. Swapping the left and right 4-column
// halves of every row gives B[x] = A[x ^ 4]. Then
//   mask = (A < B) ^ flip
// marks the lanes where the original A value is restored: flip == 0 lanes keep
// the smaller key of their pair, and flip == 1 lanes the larger one.
//
//   A             - input keys (passed by value)
//   companionA    - input payload (only read)
//   B, companionB - output keys / payload (every lane is written)
//   flip          - per-lane direction bits for one 32-lane chunk
//
// NOTE(review): equal keys in a pair make the two lanes take opposite
// decisions, which duplicates one payload value and drops the other. Only safe
// when keys are distinct.
ESIMD_INLINE void
bitonic_exchange4(esimd::simd<int32_t, BASE_SZ> A,
                  esimd::simd<int32_t, BASE_SZ> &B,
                  esimd::simd<double, BASE_SZ>& companionA, /* in */
                  esimd::simd<double, BASE_SZ>& companionB, /* out */
                  esimd::simd_mask<32> flip) {
    ESIMD_UNROLL
    for (int i = 0; i < BASE_SZ; i += 32) {
        // 4 rows x 8 columns views of the current 32-lane chunk.
        auto MA = A.select<32, 1>(i).bit_cast_view<int32_t, 4, 8>();
        auto MB = B.select<32, 1>(i).bit_cast_view<int32_t, 4, 8>();
        // Columns 0-3 <-> columns 4-7, i.e. B[x] = A[x ^ 4].
        MB.select<4, 1, 4, 1>(0, 0) = MA.select<4, 1, 4, 1>(0, 4);
        MB.select<4, 1, 4, 1>(0, 4) = MA.select<4, 1, 4, 1>(0, 0);

        // Restore A's value where it already satisfies the order given by flip.
        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);

        // Apply the same permutation and selection to the payload.
        auto McompanionB = companionB.select<32, 1>(i).bit_cast_view<double, 4, 8>();
        auto McompanionA = companionA.select<32, 1>(i).bit_cast_view<double, 4, 8>();
        McompanionB.select<4, 1, 4, 1>(0, 0) = McompanionA.select<4, 1, 4, 1>(0, 4);
        McompanionB.select<4, 1, 4, 1>(0, 4) = McompanionA.select<4, 1, 4, 1>(0, 0);

        companionB.select<32, 1>(i).merge(companionA.select<32, 1>(i), mask);
    }
}

// Stride-2 compare-and-swap pass of the bitonic network over BASE_SZ keys,
// processed in 32-lane chunks. The companion (payload) vector is permuted in
// lockstep.
//
// B first receives B[x] = A[x ^ 2] (lanes 0,1 <-> 2,3 of every 4-lane group).
// Then mask = (A < B) ^ flip marks the lanes where A's value is restored:
// flip == 0 lanes keep the smaller key of their pair, and flip == 1 lanes the
// larger one.
//
// This function carries the largest changes of the double-precision workaround.
// Instead of using stride-4 regions over the double payload, the payload is
// handled through a float view. Double lane x occupies float lanes 2x and
// 2x + 1, so one double region select<8, 4>(k) becomes the two float regions
// select<8, 8>(2k) and select<8, 8>(2k + 1).
//
//   A             - input keys (passed by value)
//   companionA    - input payload (only read)
//   B, companionB - output keys / payload (every lane is written)
//   flip          - per-lane direction bits for one 32-lane chunk
//
// NOTE(review): equal keys in a pair duplicate one payload value and drop the
// other. Only safe when keys are distinct.
ESIMD_INLINE void
bitonic_exchange2(esimd::simd<int32_t, BASE_SZ> A,
                  esimd::simd<int32_t, BASE_SZ> &B,
                  esimd::simd<double, BASE_SZ>& companionA, /* in */
                  esimd::simd<double, BASE_SZ>& companionB, /* out */
                  esimd::simd_mask<32> flip) {

    ESIMD_UNROLL
    for (int i = 0; i < BASE_SZ; i += 32) {
        // B[x] = A[x ^ 2]
        B.select<8, 4>(i) = A.select<8, 4>(i + 2);
        B.select<8, 4>(i + 1) = A.select<8, 4>(i + 3);
        B.select<8, 4>(i + 2) = A.select<8, 4>(i);
        B.select<8, 4>(i + 3) = A.select<8, 4>(i + 1);

        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);

        // bit_cast to float to work around GSD-5451 (avoids strided regions
        // over double). Each commented double-typed statement below is
        // implemented by the two float statements that follow it.
        auto cBfloat = companionB.select<32, 1>(i).bit_cast_view<float>();
        auto cAfloat = companionA.select<32, 1>(i).bit_cast_view<float>();
        // companionB.select<8, 4>(i) = companionA.select<8, 4>(i + 2);
        cBfloat.select<8, 8>(0) = cAfloat.select<8, 8>(0 + 4);
        cBfloat.select<8, 8>(0 + 1) = cAfloat.select<8, 8>(0 + 4 + 1);
        // companionB.select<8, 4>(i + 1) = companionA.select<8, 4>(i + 3);
        cBfloat.select<8, 8>(0 + 2) = cAfloat.select<8, 8>(0 + 6);
        cBfloat.select<8, 8>(0 + 2 + 1) = cAfloat.select<8, 8>(0 + 6 + 1);
        // companionB.select<8, 4>(i + 2) = companionA.select<8, 4>(i);
        cBfloat.select<8, 8>(0 + 4) = cAfloat.select<8, 8>(0);
        cBfloat.select<8, 8>(0 + 4 + 1) = cAfloat.select<8, 8>(0 + 1);
        // companionB.select<8, 4>(i + 3) = companionA.select<8, 4>(i + 1);
        cBfloat.select<8, 8>(0 + 6) = cAfloat.select<8, 8>(0 + 2);
        cBfloat.select<8, 8>(0 + 6 + 1) = cAfloat.select<8, 8>(0 + 2 + 1);
        // companionB.select<32, 1>(i).merge(companionA.select<32, 1>(i), mask);
        // Even float lanes are the first half of each double, odd float lanes
        // the second half; mask lane k applies to both halves of double k.
        cBfloat.select<32, 2>(0).merge(cAfloat.select<32, 2>(0), mask);
        cBfloat.select<32, 2>(0 + 1).merge(cAfloat.select<32, 2>(0 + 1), mask);

    }
}

// Stride-1 compare-and-swap pass of the bitonic network over BASE_SZ keys:
// every even/odd lane pair (x, x ^ 1) is ordered according to `flip`. The
// companion (payload) vector is permuted in lockstep.
//
//   A             - input keys (passed by value)
//   companionA    - input payload (only read)
//   B, companionB - output keys / payload (every lane is written)
//   flip          - per-lane direction bits for one 32-lane chunk. A lane with
//                   flip == 0 receives the smaller key of its pair, and a lane
//                   with flip == 1 the larger one.
//
// NOTE(review): equal keys in a pair make both lanes end up with the same
// payload value (the other one is lost). Only safe when keys are distinct.
ESIMD_INLINE void
bitonic_exchange1(esimd::simd<int32_t, BASE_SZ> A, /* in */
                  esimd::simd<int32_t, BASE_SZ>& B, /* out */
                  esimd::simd<double, BASE_SZ>& companionA, /* in */
                  esimd::simd<double, BASE_SZ>& companionB, /* out */
                  esimd::simd_mask<32> flip) {

  // Each iteration processes one 32-lane chunk; with BASE_SZ == 32 there is a
  // single iteration.
  ESIMD_UNROLL
  for (int i = 0; i < BASE_SZ; i += 32) {
    // T and Tcompanion are views into the current chunk of B / companionB.
    auto T = B.select<32, 1>(i);
    auto Tcompanion = companionB.select<32, 1>(i);

    // Step 1: A.select<16, 2>(i + 1) selects the 16 odd-position elements of
    // the chunk (stride 2, starting at A[i + 1]). They are copied into the
    // even positions of T.
    T.select<16, 2>(0) = A.select<16, 2>(i + 1);
    Tcompanion.select<16, 2>(0) = companionA.select<16, 2>(i + 1);

    // Step 2: the 16 even-position elements (starting at A[i]) are copied into
    // the odd positions of T. After steps 1 and 2 every even/odd pair is
    // swapped, i.e. T[x] == A[i + (x ^ 1)].
    T.select<16, 2>(1) = A.select<16, 2>(i);
    Tcompanion.select<16, 2>(1) = companionA.select<16, 2>(i);

    // Step 3: decide per lane whether to keep the swapped value or restore
    // A's original one. (A < T) is true in a lane where the original value is
    // the smaller one of its pair. XOR with flip turns that into "the original
    // value already belongs in this lane", and merge copies A's value (and its
    // payload) back into exactly those lanes.
    // For example, stage 0 uses flip1 = 0x66666666, whose per-lane bits are
    // [0,1,1,0, 0,1,1,0, ...]. Lanes 0/1 therefore end up ascending (min, max)
    // and lanes 2/3 descending (max, min).
    auto mask = flip ^ (A.select<32, 1>(i) < T);
    T.merge(A.select<32, 1>(i), mask);
    Tcompanion.merge(companionA.select<32, 1>(i), mask);
  }
}


// Bitonic merge for stage `m`, applied in place to the BASE_SZ keys held in A.
// It runs the compare-and-swap passes with strides 2^n, 2^(n-1), ..., 2, 1.
// Every key move is mirrored on the same lane of companionA, so each payload
// value stays attached to its key.
//
//   offset     - position of A[0] in the overall sequence. Together with m it
//                decides whether a lane lies in an ascending or descending run.
//   A          - keys, merged in place.
//   companionA - payload (one double per key), permuted in place alongside A.
//   n          - log2 of the first (largest) stride. Strides of 16 and above
//                use the generic loop; strides 8, 4, 2 and 1 are always run by
//                the tailored code at the end. 2^n must not exceed BASE_SZ / 2,
//                otherwise the generic loop indexes past the end of A.
//   m          - stage number. Runs of 2^(m+1) consecutive positions share one
//                direction, and the runs alternate ascending / descending.
//
// NOTE(review): in the stride-8/4/2/1 passes, two partner lanes holding EQUAL
// keys make opposite keep/take decisions, so one payload value is duplicated
// and the other lost. The generic pass keeps both in that case. This is
// harmless only when keys are distinct; confirm for callers.
ESIMD_INLINE void bitonic_merge(uint32_t offset,
                                esimd::simd<int32_t, BASE_SZ> &A,
                                esimd::simd<double, BASE_SZ> &companionA,
                                uint32_t n, uint32_t m) {
  // dist is the stride distance for compare-and-swap
  uint32_t dist = 1 << n;

  // Generic passes for strides 2^n down to 16, halving each time (n - 3 passes).
  // k counts up from 3 so that the bound needs no `n - 3`, which wraps around
  // for unsigned n < 3 and would make the loop run (practically) forever.
  for (uint32_t k = 3; k < n; k++, dist >>= 1) {
    // Groups of 2 * dist elements: element x is compared with element x + dist.
    for (int i = 0; i < BASE_SZ; i += dist * 2) {
      // Whole group shares one direction: runs of 1 << (m + 1) positions
      // (counted from offset + i) alternate ascending / descending.
      bool dir_up = (((offset + i) >> (m + 1)) & 1) == 0;
      // each iteration compare-and-swaps two 16-element chunks dist apart
      for (int j = 0; j < (dist >> 4); j++) {
        // Snapshots of the lower chunk, needed because T1 is overwritten before
        // T2 is updated. The payload snapshot must stay double: a float copy
        // would round every payload value moved by these passes.
        esimd::simd<int32_t, 16> T = A.select<16, 1>(i + j * 16);
        esimd::simd<double, 16> companionT = companionA.select<16, 1>(i + j * 16);
        auto T1 = A.select<16, 1>(i + j * 16);
        auto T2 = A.select<16, 1>(i + j * 16 + dist);
        auto companionT1 = companionA.select<16, 1>(i + j * 16);
        auto companionT2 = companionA.select<16, 1>(i + j * 16 + dist);
        if (dir_up) {
            // T1 <- min(T1, T2), T2 <- max(original T1, T2)
            auto mask = T2 < T1;
            T1.merge(T2, mask);
            companionT1.merge(companionT2, mask);
            auto mask2 = T > T2;
            T2.merge(T, mask2);
            companionT2.merge(companionT, mask2);
        } else {
            // T1 <- max(T1, T2), T2 <- min(original T1, T2)
            auto mask = T2 > T1;
            T1.merge(T2, mask);
            companionT1.merge(companionT2, mask);
            auto mask2 = T < T2;
            T2.merge(T, mask2);
            companionT2.merge(companionT, mask2);
        }
      }
    }
  }

  // Stride 1, 2, 4, and 8 in bitonic_merge are custom tailored to
  // take advantage of register regioning. The implementation is
  // similar to bitonic_exchange{1,2,4,8}. The passes alternate between A and
  // the scratch vector B: stride 8 writes B, 4 writes A, 2 writes B and 1
  // writes A, so the final result is back in A.

  // exchange 8
  esimd::simd_mask<32> flip13 = esimd::unpack_mask<32>(0xff00ff00); //(init_mask13);
  esimd::simd_mask<32> flip14 = esimd::unpack_mask<32>(0x00ff00ff); //(init_mask14);
  esimd::simd<int32_t, BASE_SZ> B;
  esimd::simd<double, BASE_SZ> companionB;
  ESIMD_UNROLL
  for (int i = 0; i < BASE_SZ; i += 32) {
    // B[x] = A[x ^ 8]
    B.select<8, 1>(i) = A.select<8, 1>(i + 8);
    B.select<8, 1>(i + 8) = A.select<8, 1>(i);
    B.select<8, 1>(i + 16) = A.select<8, 1>(i + 24);
    B.select<8, 1>(i + 24) = A.select<8, 1>(i + 16);
    companionB.select<8, 1>(i) = companionA.select<8, 1>(i + 8);
    companionB.select<8, 1>(i + 8) = companionA.select<8, 1>(i);
    companionB.select<8, 1>(i + 16) = companionA.select<8, 1>(i + 24);
    companionB.select<8, 1>(i + 24) = companionA.select<8, 1>(i + 16);
    bool dir_up = (((offset + i) >> (m + 1)) & 1) == 0;
    if (dir_up) {
        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip13;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);
        companionB.select<32, 1>(i).merge(companionA.select<32, 1>(i), mask);
    }
    else {
        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip14;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);
        companionB.select<32, 1>(i).merge(companionA.select<32, 1>(i), mask);
    }
  }

  // exchange 4
  esimd::simd_mask<32> flip15 = esimd::unpack_mask<32>(0xf0f0f0f0); //(init_mask15);
  esimd::simd_mask<32> flip16 = esimd::unpack_mask<32>(0x0f0f0f0f); //(init_mask16);
  ESIMD_UNROLL
  for (int i = 0; i < BASE_SZ; i += 32) {
    // A[x] = B[x ^ 4] via 4x8 matrix views (swap column halves).
    auto MA = A.select<32, 1>(i).bit_cast_view<int32_t, 4, 8>();
    auto MB = B.select<32, 1>(i).bit_cast_view<int32_t, 4, 8>();
    MA.select<4, 1, 4, 1>(0, 0) = MB.select<4, 1, 4, 1>(0, 4);
    MA.select<4, 1, 4, 1>(0, 4) = MB.select<4, 1, 4, 1>(0, 0);
    auto companionMA = companionA.select<32, 1>(i).bit_cast_view<double, 4, 8>();
    auto companionMB = companionB.select<32, 1>(i).bit_cast_view<double, 4, 8>();
    companionMA.select<4, 1, 4, 1>(0, 0) = companionMB.select<4, 1, 4, 1>(0, 4);
    companionMA.select<4, 1, 4, 1>(0, 4) = companionMB.select<4, 1, 4, 1>(0, 0);
    bool dir_up = (((offset + i) >> (m + 1)) & 1) == 0;
    if (dir_up) {
        auto mask = (B.select<32, 1>(i) < A.select<32, 1>(i)) ^ flip15;
        A.select<32, 1>(i).merge(B.select<32, 1>(i), mask);
        companionA.select<32, 1>(i).merge(companionB.select<32, 1>(i), mask);
    }
    else {
        auto mask = (B.select<32, 1>(i) < A.select<32, 1>(i)) ^ flip16;
        A.select<32, 1>(i).merge(B.select<32, 1>(i), mask);
        companionA.select<32, 1>(i).merge(companionB.select<32, 1>(i), mask);
    }
  }

  // exchange 2
  esimd::simd_mask<32> flip17 = esimd::unpack_mask<32>(0xcccccccc); //(init_mask17);
  esimd::simd_mask<32> flip18 = esimd::unpack_mask<32>(0x33333333); //(init_mask18);
  ESIMD_UNROLL
  for (int i = 0; i < BASE_SZ; i += 32) {
    // B[x] = A[x ^ 2]
    B.select<8, 4>(i) = A.select<8, 4>(i + 2);
    B.select<8, 4>(i + 1) = A.select<8, 4>(i + 3);
    B.select<8, 4>(i + 2) = A.select<8, 4>(i);
    B.select<8, 4>(i + 3) = A.select<8, 4>(i + 1);

    // bit_cast to float to avoid larger strides in double, GSD-5451.
    // Double lane x occupies float lanes 2x and 2x + 1.
    auto cBfloat = companionB.select<32, 1>(i).bit_cast_view<float>();
    auto cAfloat = companionA.select<32, 1>(i).bit_cast_view<float>();
    // companionB.select<8, 4>(i) = companionA.select<8, 4>(i + 2);
    cBfloat.select<8, 8>(0) = cAfloat.select<8, 8>(0 + 4);
    cBfloat.select<8, 8>(0 + 1) = cAfloat.select<8, 8>(0 + 4 + 1);
    // companionB.select<8, 4>(i + 1) = companionA.select<8, 4>(i + 3);
    cBfloat.select<8, 8>(0 + 2) = cAfloat.select<8, 8>(0 + 6);
    cBfloat.select<8, 8>(0 + 2 + 1) = cAfloat.select<8, 8>(0 + 6 + 1);
    // companionB.select<8, 4>(i + 2) = companionA.select<8, 4>(i);
    cBfloat.select<8, 8>(0 + 4) = cAfloat.select<8, 8>(0);
    cBfloat.select<8, 8>(0 + 4 + 1) = cAfloat.select<8, 8>(0 + 1);
    // companionB.select<8, 4>(i + 3) = companionA.select<8, 4>(i + 1);
    cBfloat.select<8, 8>(0 + 6) = cAfloat.select<8, 8>(0 + 2);
    cBfloat.select<8, 8>(0 + 6 + 1) = cAfloat.select<8, 8>(0 + 2 + 1);

    bool dir_up = (((offset + i) >> (m + 1)) & 1) == 0;
    if (dir_up) {
        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip17;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);
        // Payload merge done on the two float halves of each double
        // (equivalent to merging companionB with companionA under mask).
        cBfloat.select<32, 2>(0).merge(cAfloat.select<32, 2>(0), mask);
        cBfloat.select<32, 2>(0 + 1).merge(cAfloat.select<32, 2>(0 + 1), mask);
    }
    else {
        auto mask = (A.select<32, 1>(i) < B.select<32, 1>(i)) ^ flip18;
        B.select<32, 1>(i).merge(A.select<32, 1>(i), mask);
        // Payload merge done on the two float halves of each double
        // (equivalent to merging companionB with companionA under mask).
        cBfloat.select<32, 2>(0).merge(cAfloat.select<32, 2>(0), mask);
        cBfloat.select<32, 2>(0 + 1).merge(cAfloat.select<32, 2>(0 + 1), mask);
    }
  }

  // exchange 1
  esimd::simd_mask<32> flip19 = esimd::unpack_mask<32>(0xaaaaaaaa); //(init_mask19);
  esimd::simd_mask<32> flip20 = esimd::unpack_mask<32>(0x55555555); //(init_mask20);
  ESIMD_UNROLL
  // Each iteration compares and swaps 2 32-element chunks
  for (int i = 0; i < BASE_SZ; i += 32) {
    // As in bitonic_exchange1: swap the even and odd elements of B and put
    // them in A, so that A[x] = B[x ^ 1].
    auto T = A.select<32, 1>(i);
    T.select<16, 2>(0) = B.select<16, 2>(i + 1);
    T.select<16, 2>(1) = B.select<16, 2>(i);
    auto companionT = companionA.select<32, 1>(i);
    companionT.select<16, 2>(0) = companionB.select<16, 2>(i + 1);
    companionT.select<16, 2>(1) = companionB.select<16, 2>(i);
    // Whether data lie in an ascending or descending run depends on their
    // position ("offset+i") and on the stage m: chunks of 1<<(m+1) elements
    // share one order in all stride steps, and "&1" alternates the direction
    // from one chunk to the next.
    bool dir_up = (((offset + i) >> (m + 1)) & 1) == 0;
    // choose flip vector based on the direction (ascending or descending).
    // Compare and swap
    if (dir_up) {
        auto mask = (B.select<32, 1>(i) < T) ^ flip19;
        A.select<32, 1>(i).merge(B.select<32, 1>(i), mask);
        companionA.select<32, 1>(i).merge(companionB.select<32, 1>(i), mask);
    }
    else {
        auto mask = (B.select<32, 1>(i) < T) ^ flip20;
        A.select<32, 1>(i).merge(B.select<32, 1>(i), mask);
        companionA.select<32, 1>(i).merge(companionB.select<32, 1>(i), mask);
    }
  }
}

// Sorts the keys in A into ascending order while permuting companionA in
// lockstep, so that each payload value stays attached to its key.
//
//   A          - keys. Lanes at or beyond `length` are overwritten with
//                INT32_MAX first, so they sort after every real key.
//   companionA - payload, one double per key.
//   length     - number of meaningful lanes in A / companionA.
//
// Stages 0-3 sort runs of 2, 4, 8 and 16 lanes with the register-regioned
// bitonic_exchange{1,2,4,8} helpers. These alternate between A and the scratch
// vectors B / companionB, and after stage 3 the result is back in A. The
// remaining stages use bitonic_merge in place. The helpers take BASE_SZ-wide
// vectors by reference, so only MAXROWSIZE == BASE_SZ can be instantiated.
//
// NOTE(review): the exchange passes lose a payload value when two paired keys
// are equal. This is only safe if the real keys are distinct; confirm for
// callers.
template <int MAXROWSIZE>
ESIMD_INLINE void bitonic_exchange_simd(esimd::simd<int32_t, MAXROWSIZE> &A,
                                        esimd::simd<double, MAXROWSIZE> &companionA,
                                        int32_t length) {
  esimd::simd<int32_t, MAXROWSIZE> B;
  esimd::simd<double, MAXROWSIZE> companionB;
  if (length < MAXROWSIZE) {
      // Pad lanes [length, MAXROWSIZE) with INT32_MAX so they sort to the end.
      // The mask is built with unsigned arithmetic: the signed expression
      // (1 << (MAXROWSIZE - length)) - 1 overflows for length == 1 and shifts
      // by 32 (undefined) for length == 0.
      const uint32_t valid_lanes =
          (length <= 0) ? 0u : ((1u << static_cast<uint32_t>(length)) - 1u);
      const esimd::simd_mask<MAXROWSIZE> dummy_mask =
          esimd::unpack_mask<MAXROWSIZE>(~valid_lanes);
      A.merge(std::numeric_limits<int32_t>::max(), dummy_mask);
  }

  esimd::simd_mask<32> flip1 = esimd::unpack_mask<32>(0x66666666); //(init_mask1);

  // stage 0: runs of 2
  bitonic_exchange1(A, B, companionA, companionB, flip1);

  // stage 1: runs of 4
  esimd::simd_mask<32> flip2 = esimd::unpack_mask<32>(0x3c3c3c3c); //(init_mask2);
  esimd::simd_mask<32> flip3 = esimd::unpack_mask<32>(0x5a5a5a5a); //(init_mask3);
  bitonic_exchange2(B, A, companionB, companionA, flip2);
  bitonic_exchange1(A, B, companionA, companionB, flip3);

  // stage 2: runs of 8
  esimd::simd_mask<32> flip4 = esimd::unpack_mask<32>(0x0ff00ff0); //(init_mask4);
  esimd::simd_mask<32> flip5 = esimd::unpack_mask<32>(0x33cc33cc); //(init_mask5);
  esimd::simd_mask<32> flip6 = esimd::unpack_mask<32>(0x55aa55aa); //(init_mask6);
  bitonic_exchange4(B, A, companionB, companionA, flip4);
  bitonic_exchange2(A, B, companionA, companionB, flip5);
  bitonic_exchange1(B, A, companionB, companionA, flip6);

  // stage 3: runs of 16
  esimd::simd_mask<32> flip7 = esimd::unpack_mask<32>(0x00ffff00);  //(init_mask7);
  esimd::simd_mask<32> flip8 = esimd::unpack_mask<32>(0x0f0ff0f0);  //(init_mask8);
  esimd::simd_mask<32> flip9 = esimd::unpack_mask<32>(0x3333cccc);  //(init_mask9);
  esimd::simd_mask<32> flip10 = esimd::unpack_mask<32>(0x5555aaaa); //(init_mask10);
  bitonic_exchange8(A, B, companionA, companionB, flip7);
  bitonic_exchange4(B, A, companionB, companionA, flip8);
  bitonic_exchange2(A, B, companionA, companionB, flip9);
  bitonic_exchange1(B, A, companionB, companionA, flip10);

  // Stage 4 onwards use the generic bitonic_merge routine. Stage s merges
  // sorted runs of 2^s keys into runs of 2^(s+1), and the last stage satisfies
  // 2^(s+1) == MAXROWSIZE (for MAXROWSIZE == 32 that is stage 4 only). The
  // bound uses integer shifts instead of a floating-point log2.
  for (uint32_t stage = 4; (2u << stage) <= static_cast<uint32_t>(MAXROWSIZE); stage++) {
      bitonic_merge(0, A, companionA, stage, stage);
  }
}

// Sorts the first `size` int32 keys in buf, moving the matching double entries
// of buf_companion along with them. Both arrays are loaded into BASE_SZ-lane
// vectors with masked gathers, sorted in registers by bitonic_exchange_simd,
// and written back with masked scatters.
// NOTE(review): assumes size <= BASE_SZ; longer rows are not handled here.
ESIMD_INLINE void bitonic_exchange_wrap(int32_t *buf,
                                        double *buf_companion,
                                        int32_t size) {
  // Load keys and payload into registers.
  esimd::simd<int32_t, BASE_SZ> keys = gather_read<int32_t, BASE_SZ>(buf, size);
  esimd::simd<double, BASE_SZ> payload =
      gather_read<double, BASE_SZ>(buf_companion, size);

  // Sort the keys; the payload is permuted alongside.
  bitonic_exchange_simd<BASE_SZ>(keys, payload, size);

  // Store back only the first `size` lanes of each vector.
  scatter_write<int32_t, BASE_SZ>(buf, size, keys);
  scatter_write<double, BASE_SZ>(buf_companion, size, payload);
}
