// Copyright 2023 Tencent Inc.  All rights reserved.
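/**
 * @file split_string_by_char.cc
 * @brief Experiments with using SIMD instructions to speed up string splitting.
 *
 * Conclusions for splitting a string on a single character:
 * 1. Without SIMD, a plain two-pointer scan already performs very well: it is almost as fast as
 *    the SIMD versions, and it is faster than them on short inputs.
 * 2. With SIMD:
 *    (1) for long strings, the best performance requires handling the result of each SIMD
 *        comparison precisely, delimiter by delimiter (the "Simd" version in this file);
 *    (2) for short strings, it is better to let the SIMD instructions repeat some work (the
 *        "SimdV2" version in this file), which removes the branching needed for precise handling.
 * 3. In this benchmark, constructing each string_view and appending it to the result vector also
 *    takes a large share of the time. With those two operations commented out, the long-text
 *    numbers are:
 *
 *      simd cost_ms:25
 *      simd V2 cost_ms:92
 *      no simd V1 cost_ms:157
 *      no simd V2 cost_ms:26
 *      no simd V3 cost_ms:62
 *
 *    With those two operations restored, the numbers are:
 *
 *      simd cost_ms:253
 *      simd V2 cost_ms:303
 *      no simd V1 cost_ms:358
 *      no simd V2 cost_ms:256
 *      no simd V3 cost_ms:304
 *
 * 4. Recommendation: for everyday use SplitStringV2 is sufficient. SIMD is not needed: the code
 *    is considerably more complex and the gain is small.
 *    Exception: when the string being split may be modified in place, the mask handling can
 *    simply overwrite the delimiters (e.g. with an empty/terminator character), and in that case
 *    performance improves substantially.
 */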

#include <immintrin.h>

#include <bitset>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

// Reference single-character splitter.
//
// Tracks the current token with an explicit start offset plus a running length, which costs an
// extra counter update per character; the original author measured this as the slowest variant
// in this file. Empty tokens (consecutive, leading or trailing delimiters) are skipped.
//
// @param input      Text to split. The returned views point into this buffer, so it must outlive
//                   the result.
// @param delimiter  Character that separates tokens.
// @return           The non-empty tokens, in order of appearance.
std::vector<std::string_view> SplitString(std::string_view input, char delimiter) {
  std::vector<std::string_view> tokens;
  // Offsets and lengths are std::size_t so that they match input.size() and cannot overflow or be
  // narrowed for inputs larger than INT_MAX.
  std::size_t token_start = 0;
  std::size_t token_length = 0;
  for (std::size_t i = 0; i < input.size(); ++i) {
    if (input[i] == delimiter) {
      if (token_length != 0) {
        tokens.emplace_back(input.data() + token_start, token_length);
        token_length = 0;
      }
      // The next token (if any) starts right after this delimiter.
      token_start = i + 1;
    } else {
      ++token_length;
    }
  }
  // Flush the trailing token when the input does not end with a delimiter.
  if (token_length > 0) {
    tokens.emplace_back(input.data() + token_start, token_length);
  }
  return tokens;
}

// Fastest non-SIMD splitter in this file (per the author's measurements).
//
// Walks the input once with two pointers: `piece_begin` marks where the current token started and
// `cursor` is the character being examined, so no per-character length bookkeeping is needed.
// Empty tokens (consecutive, leading or trailing delimiters) are skipped.
//
// @param input      Text to split. The returned views point into this buffer.
// @param delimiter  Character that separates tokens.
// @return           The non-empty tokens, in order of appearance.
std::vector<std::string_view> SplitStringV2(std::string_view input, char delimiter) {
  std::vector<std::string_view> pieces;
  const char* const last = input.data() + input.size();
  const char* piece_begin = input.data();
  const char* cursor = piece_begin;
  while (cursor != last) {
    if (*cursor == delimiter) {
      // Emit the token that ends just before this delimiter, unless it is empty.
      if (cursor != piece_begin) {
        pieces.emplace_back(piece_begin, cursor - piece_begin);
      }
      piece_begin = cursor + 1;
    }
    ++cursor;
  }
  // Whatever follows the last delimiter is the final token.
  if (last != piece_begin) {
    pieces.emplace_back(piece_begin, last - piece_begin);
  }
  return pieces;
}

// Non-SIMD splitter built on std::string_view::find_first_of.
//
// The original author measured it at roughly half the speed of the hand-written scan in
// SplitStringV2. Empty tokens (consecutive, leading or trailing delimiters) are skipped.
//
// @param input      Text to split. The returned views point into this buffer.
// @param delimiter  Character that separates tokens.
// @return           The non-empty tokens, in order of appearance.
std::vector<std::string_view> SplitStringV3(std::string_view input, char delimiter) {
  std::vector<std::string_view> pieces;
  for (size_t begin = 0; begin < input.size();) {
    const size_t hit = input.find_first_of(delimiter, begin);
    if (hit == std::string_view::npos) {
      // No more delimiters: the remainder (non-empty, because begin < size) is the last token.
      pieces.emplace_back(input.substr(begin));
      break;
    }
    if (hit != begin) {
      pieces.emplace_back(input.substr(begin, hit - begin));
    }
    begin = hit + 1;
  }
  return pieces;
}

// AVX2 splitter; the author's measurements show it as the best variant for long inputs.
//
// The input is processed in 32-byte blocks. Each block is compared against the delimiter in one
// instruction, the comparison is condensed into a 32-bit mask (bit i set <=> byte i of the block
// is a delimiter), and every set bit of the mask is then visited to emit the tokens that end
// inside the block. Characters after the last delimiter of a block are carried over to the next
// block via `pending`. The final partial block (fewer than 32 bytes) is handled by a scalar scan.
// Inputs shorter than one block are delegated to SplitStringV2.
// Empty tokens (consecutive, leading or trailing delimiters) are skipped.
//
// @param input      Text to split. The returned views point into this buffer.
// @param delimiter  Character that separates tokens.
// @return           The non-empty tokens, in order of appearance.
std::vector<std::string_view> SplitStringWithSimd256(std::string_view input, char delimiter) {
  constexpr std::size_t kBlockSize = 32;  // bytes covered by one 256-bit load
  if (input.size() < kBlockSize) {
    return SplitStringV2(input, delimiter);
  }

  std::vector<std::string_view> tokens;

  // Length of the prefix that consists of whole 32-byte blocks. Kept in std::size_t so that
  // inputs of 4 GiB or more are not truncated.
  const std::size_t simd_len = input.size() / kBlockSize * kBlockSize;
  const __m256i needle = _mm256_set1_epi8(delimiter);  // delimiter replicated into all 32 lanes
  const char* p = input.data();                        // always at a block boundary in the loop
  const char* const simd_end = p + simd_len;
  // Number of characters immediately before `p` that belong to a token which has not been
  // emitted yet (it started in an earlier block and has not hit a delimiter so far).
  std::size_t pending = 0;
  while (p < simd_end) {
    const __m256i block = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p));
    uint32_t mask = static_cast<uint32_t>(_mm256_movemask_epi8(_mm256_cmpeq_epi8(needle, block)));

    if (mask == 0) {
      // No delimiter in this block: the whole block extends the pending token.
      pending += kBlockSize;
      p += kBlockSize;
      continue;
    }

    // Visit the delimiters of this block from lowest to highest address. The lowest set bit is
    // removed with `mask &= mask - 1` instead of shifting the mask right by (offset + 1): that
    // shift would be by 32 bits -- undefined behaviour for a 32-bit operand -- whenever the only
    // delimiter of a block is its last byte.
    const char* token_start = p - pending;
    do {
      const char* const hit = p + __builtin_ctz(mask);
      if (hit != token_start) {
        tokens.emplace_back(token_start, static_cast<std::size_t>(hit - token_start));
      }
      token_start = hit + 1;
      mask &= mask - 1;
    } while (mask != 0);

    // Everything after the last delimiter of this block is carried into the next block.
    p += kBlockSize;
    pending = static_cast<std::size_t>(p - token_start);
  }

  // Scalar scan over what is left after the last whole block. The carried-over characters
  // contain no delimiter, so scanning can start at `p` while the token starts `pending` earlier.
  const char* token_start = p - pending;
  const char* const sentence_end = input.data() + input.size();
  for (; p != sentence_end; ++p) {
    if (*p == delimiter) {
      if (p != token_start) {
        tokens.emplace_back(token_start, static_cast<std::size_t>(p - token_start));
      }
      token_start = p + 1;
    }
  }
  if (p != token_start) {
    tokens.emplace_back(token_start, static_cast<std::size_t>(p - token_start));
  }
  return tokens;
}

// AVX2 splitter that consumes only ONE delimiter per 32-byte load.
//
// After each comparison only the first delimiter in the loaded window is used; the window then
// restarts right after that delimiter, so bytes are compared repeatedly. This wastes SIMD work
// but avoids the per-delimiter mask bookkeeping; the author found it performs well when the text
// is only slightly longer than 32 characters. Inputs shorter than 32 bytes are delegated to
// SplitStringV2. Empty tokens (consecutive, leading or trailing delimiters) are skipped.
//
// @param input      Text to split. The returned views point into this buffer.
// @param delimiter  Character that separates tokens.
// @return           The non-empty tokens, in order of appearance.
std::vector<std::string_view> SplitStringWithSimd256V2(std::string_view input, char delimiter) {
  constexpr std::ptrdiff_t kBlockSize = 32;  // bytes covered by one 256-bit load
  if (input.size() < 32) {
    return SplitStringV2(input, delimiter);
  }

  std::vector<std::string_view> tokens;
  const __m256i needle = _mm256_set1_epi8(delimiter);  // delimiter replicated into all 32 lanes
  const char* p = input.data();
  const char* const sentence_end = input.data() + input.size();
  // Number of delimiter-free characters immediately before `p` that belong to a token which has
  // not been emitted yet. std::size_t so that very long delimiter-free runs cannot wrap around.
  std::size_t pending = 0;
  // Compare the remaining distance instead of computing `p + 32`, which would form a pointer
  // beyond one-past-the-end of the buffer (undefined behaviour) near the end of the input.
  while (sentence_end - p >= kBlockSize) {
    const __m256i block = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p));
    const uint32_t mask = static_cast<uint32_t>(_mm256_movemask_epi8(_mm256_cmpeq_epi8(needle, block)));
    if (mask == 0) {
      // No delimiter in this window: it entirely extends the pending token.
      pending += static_cast<std::size_t>(kBlockSize);
      p += kBlockSize;
      continue;
    }

    // Offset of the first delimiter inside the window (bit i of the mask <=> byte i).
    const uint32_t first_hit = static_cast<uint32_t>(__builtin_ctz(mask));
    if (pending != 0 || first_hit != 0) {
      tokens.emplace_back(p - pending, pending + first_hit);
    }
    pending = 0;
    // Restart the next window right after the delimiter that was just consumed.
    p += first_hit + 1;
  }

  // Scalar scan over the final stretch that is shorter than one 32-byte window.
  const char* token_start = p - pending;
  for (; p != sentence_end; ++p) {
    if (*p == delimiter) {
      if (p != token_start) {
        tokens.emplace_back(token_start, static_cast<std::size_t>(p - token_start));
      }
      token_start = p + 1;
    }
  }
  if (p != token_start) {
    tokens.emplace_back(token_start, static_cast<std::size_t>(p - token_start));
  }
  return tokens;
}

int main() {
  std::vector<std::string_view> sentences;
  // sentences.emplace_back(",,,,,,,,,,");
  // sentences.emplace_back(",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,");
  // sentences.emplace_back(",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,");
  // sentences.emplace_back("1111111111");
  // sentences.emplace_back("1111111111111111111111111111111111111");
  // sentences.emplace_back("11,1111111111111,11111111111111111111111,11");
  // sentences.emplace_back("11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111");
  sentences.emplace_back(
      "11 11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11 11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,1111111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1111111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "11,11,11,111,111,111,111,1,1,1,,1,11,1,1111111111111,1,1,11111111,"
      "111111111111111111111111111111111111111111,111111111");
  // SplitStringWithSimd256
  {
    // 单元测试
    for (const std::string_view& item : sentences) {
      auto output_1 = SplitStringWithSimd256(item, ',');
      auto output_2 = SplitString(item, ',');
      for (int i = 0; i != output_1.size(); ++i) {
        if (output_1[i] != output_2[i]) {
          std::cout << "SplitStringWithSimd256 has error" << std::endl;
          std::cout << i << ":" << output_1[i] << std::endl;
          std::cout << i << ":" << output_2[i] << std::endl;
        }
      }
    }

    auto start_commit = std::chrono::steady_clock::now().time_since_epoch();
    for (int i = 0; i != 10000; ++i) {
      for (const std::string_view& item : sentences) {
        SplitStringWithSimd256(item, ',');
      }
    }
    auto end_commit = std::chrono::steady_clock::now().time_since_epoch();
    std::cout << "simd cost_ms:"
              << std::chrono::duration_cast<std::chrono::milliseconds>(end_commit - start_commit).count() << std::endl;
  }

  // SplitStringWithSimd256V2
  {
    // 单元测试
    for (const std::string_view& item : sentences) {
      auto output_1 = SplitStringWithSimd256V2(item, ',');
      auto output_2 = SplitString(item, ',');
      for (int i = 0; i != output_1.size(); ++i) {
        if (output_1[i] != output_2[i]) {
          std::cout << "SplitStringWithSimd256V2 has error" << std::endl;
        }
      }
    }

    auto start_commit = std::chrono::steady_clock::now().time_since_epoch();
    for (int i = 0; i != 10000; ++i) {
      for (const std::string_view& item : sentences) {
        SplitStringWithSimd256V2(item, ',');
      }
    }
    auto end_commit = std::chrono::steady_clock::now().time_since_epoch();
    std::cout << "simd V2 cost_ms:"
              << std::chrono::duration_cast<std::chrono::milliseconds>(end_commit - start_commit).count() << std::endl;
  }

  {
    auto start_commit = std::chrono::steady_clock::now().time_since_epoch();
    for (int i = 0; i != 10000; ++i) {
      for (const std::string_view& item : sentences) {
        SplitString(item, ',');
      }
    }
    auto end_commit = std::chrono::steady_clock::now().time_since_epoch();
    std::cout << "no simd V1 cost_ms:"
              << std::chrono::duration_cast<std::chrono::milliseconds>(end_commit - start_commit).count() << std::endl;
  }

  // SplitStringV2
  {
    // 单元测试
    for (const std::string_view& item : sentences) {
      auto output_1 = SplitStringV2(item, ',');
      auto output_2 = SplitString(item, ',');
      for (int i = 0; i != output_1.size(); ++i) {
        if (output_1[i] != output_2[i]) {
          std::cout << "SplitStringV2 has error" << std::endl;
        }
      }
    }

    auto start_commit = std::chrono::steady_clock::now().time_since_epoch();
    for (int i = 0; i != 10000; ++i) {
      for (const std::string_view& item : sentences) {
        SplitStringV2(item, ',');
      }
    }
    auto end_commit = std::chrono::steady_clock::now().time_since_epoch();
    std::cout << "no simd V2 cost_ms:"
              << std::chrono::duration_cast<std::chrono::milliseconds>(end_commit - start_commit).count() << std::endl;
  }

  // SplitStringV3
  {
    // 单元测试
    for (const std::string_view& item : sentences) {
      auto output_1 = SplitStringV3(item, ',');
      auto output_2 = SplitString(item, ',');
      if (output_1.size() != output_2.size()) {
        std::cout << "SplitStringV3 has error" << std::endl;
        return 0;
      }
      for (int i = 0; i != output_1.size(); ++i) {
        if (output_1[i] != output_2[i]) {
          std::cout << "SplitStringV3 has error" << std::endl;
          std::cout << i << ":" << output_1[i] << std::endl;
          std::cout << i << ":" << output_2[i] << std::endl;
        }
      }
    }

    auto start_commit = std::chrono::steady_clock::now().time_since_epoch();
    for (int i = 0; i != 10000; ++i) {
      for (const std::string_view& item : sentences) {
        SplitStringV3(item, ',');
      }
    }
    auto end_commit = std::chrono::steady_clock::now().time_since_epoch();
    std::cout << "no simd V3 cost_ms:"
              << std::chrono::duration_cast<std::chrono::milliseconds>(end_commit - start_commit).count() << std::endl;
  }

  return 0;
}