Login
Artifact [11e67275ac]
Login

Artifact 11e67275acde07ef6badc0500e81afbbd2979cd169b2fdd95f94fccdb3e7d2bf:


#### MIT License
####
#### Copyright (c) 2023-2024 Remilia Scarlet
#### Copyright (c) 2018 Melnik Alexander
####
#### Permission is hereby granted, free of charge, to any person obtaining a
#### copy of this software and associated documentation files (the "Software"),
#### to deal in the Software without restriction, including without limitation
#### the rights to use, copy, modify, merge, publish, distribute, sublicense,
#### and/or sell copies of the Software, and to permit persons to whom the
#### Software is furnished to do so, subject to the following conditions:
####
#### The above copyright notice and this permission notice shall be included in
#### all copies or substantial portions of the Software.
####
#### THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#### IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#### FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#### AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#### LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#### FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
#### DEALINGS IN THE SOFTWARE.
require "./xxhash-common"

####
#### Ported from C#:
#### https://github.com/uranium62/xxHash/tree/6b20e7f7b32dfc29e5019d3d35f5b7270f1656f3
####

module RemiLib::Digest
  class XXHash3 < XXHashInternal
    # Picks the hashing routine that matches the size of *input* and returns
    # its 64-bit result.
    #
    # The three short/mid-size routines receive the raw pointer of *secret*,
    # while the long-input routine receives the `Bytes` itself.
    @[AlwaysInline]
    private def self.xxh64BitsInternal(input : Bytes, seed : UInt64, secret : Bytes) : UInt64
      size = input.size
      if size <= 16
        len0To16_64B(input, secret.to_unsafe, seed)
      elsif size <= 128
        len17To128_64B(input, secret.to_unsafe, seed)
      elsif size <= XXH3_MIDSIZE_MAX
        len129To240_64B(input, secret.to_unsafe, seed)
      else
        hashLong64BWithSeed(input, seed, secret)
      end
    end

    # Hashes an *input* of at most 16 bytes by delegating to the routine for
    # its size class.
    #
    # An empty *input* is hashed from *seed* and the two secret words at byte
    # offsets 56 and 64 only.
    @[AlwaysInline]
    private def self.len0To16_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      size = input.size
      if size > 8
        len9To116_64B(input, secret, seed)
      elsif size >= 4
        len4To8_64B(input, secret, seed)
      elsif size != 0
        len1To3_64B(input, secret, seed)
      else
        emptyKey = readLE64(secret + 56) ^ readLE64(secret + 64)
        avalancheXXH64(seed ^ emptyKey)
      end
    end

    # Hashes an *input* that `len0To16_64B` routes here (more than 8 bytes).
    #
    # The first and the last 8 bytes of *input* are read (they may overlap),
    # each keyed with a pair of secret words adjusted by *seed*.
    @[AlwaysInline]
    private def self.len9To116_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      size = input.size
      data = input.to_unsafe

      # Secret-derived masks: one offset by +seed, the other by -seed.
      lowMask : UInt64 = (readLE64(secret + 24) ^ readLE64(secret + 32)) &+ seed
      highMask : UInt64 = (readLE64(secret + 40) ^ readLE64(secret + 48)) &- seed

      head : UInt64 = readLE64(data) ^ lowMask
      tail : UInt64 = readLE64(data + (size &- 8)) ^ highMask

      sum : UInt64 = size.to_u64! &+ swap64(head)
      sum = sum &+ tail
      sum = sum &+ mult128Fold64(head, tail)
      avalanche(sum)
    end

    # Hashes an *input* that `len0To16_64B` routes here (4 to 8 bytes).
    #
    # Reads the first and the last 4 bytes of *input* (they may overlap),
    # packs them into one 64-bit word, keys it with the secret and finishes
    # with `rrmxmx`.
    @[AlwaysInline]
    private def self.len4To8_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      size = input.size
      data = input.to_unsafe

      # Fold the byte-swapped low half of the seed into its upper 32 bits.
      mixedSeed : UInt64 = seed ^ swap32(seed.to_u32!).to_u64!.unsafe_shl(32)

      head : UInt32 = readLE32(data)
      tail : UInt32 = readLE32(data + (size &- 4))
      mask : UInt64 = (readLE64(secret + 8) ^ readLE64(secret + 16)) &- mixedSeed

      # The two halves occupy disjoint bits, so OR-ing them equals adding them.
      packed : UInt64 = head.to_u64!.unsafe_shl(32) | tail.to_u64!
      rrmxmx(packed ^ mask, size)
    end

    # Hashes an *input* that `len0To16_64B` routes here (1 to 3 bytes).
    #
    # The first, middle and last bytes plus the length are packed into a
    # single 32-bit word, which is keyed with the first two 32-bit secret
    # words and *seed*.
    @[AlwaysInline]
    private def self.len1To3_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      size = input.size
      first : UInt8 = input.unsafe_fetch(0)
      middle : UInt8 = input.unsafe_fetch(size.unsafe_shr(1))
      last : UInt8 = input.unsafe_fetch(size &- 1)

      # Layout, high to low byte: middle | first | length | last.
      packed : UInt32 = middle.to_u32!.unsafe_shl(24) |
                        first.to_u32!.unsafe_shl(16) |
                        size.to_u32!.unsafe_shl(8) |
                        last.to_u32!

      mask : UInt64 = (readLE32(secret) ^ readLE32(secret + 4)).to_u64! &+ seed
      avalancheXXH64(packed.to_u64! ^ mask)
    end

    # Hashes an *input* that `xxh64BitsInternal` routes here (more than 16 and
    # at most 128 bytes).
    #
    # Pairs of 16-byte chunks are mixed symmetrically from the start and the
    # end of *input*; longer inputs add further pairs nearer the middle.
    @[AlwaysInline]
    private def self.len17To128_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      size = input.size
      data = input.to_unsafe
      sum : UInt64 = size.to_u64! &* XXH_PRIME64_1

      # Each threshold implies the ones below it, so the checks can be flat
      # while the pairs are still added innermost-first.
      if size > 96
        sum &+= mix16B(data + 48, secret + 96, seed)
        sum &+= mix16B(data + (size &- 64), secret + 112, seed)
      end

      if size > 64
        sum &+= mix16B(data + 32, secret + 64, seed)
        sum &+= mix16B(data + (size &- 48), secret + 80, seed)
      end

      if size > 32
        sum &+= mix16B(data + 16, secret + 32, seed)
        sum &+= mix16B(data + (size &- 32), secret + 48, seed)
      end

      sum &+= mix16B(data, secret, seed)
      sum &+= mix16B(data + (size &- 16), secret + 16, seed)
      avalanche(sum)
    end

    # Hashes a mid-sized *input*: the sizes `xxh64BitsInternal` routes here,
    # i.e. longer than 128 bytes and at most `XXH3_MIDSIZE_MAX`.
    #
    # The first eight 16-byte chunks are mixed and avalanched, then any
    # remaining whole chunks are mixed against the secret starting at
    # `XXH3_MIDSIZE_STARTOFFSET`, and finally the last 16 bytes of *input* are
    # mixed in.
    #
    # All accumulator arithmetic is wrapping (`&*`, `&+`).  The hash relies on
    # 64-bit modular arithmetic, and the checked `*`/`+` operators raise
    # `OverflowError` as soon as a product or sum exceeds `UInt64::MAX`.
    @[AlwaysInline]
    private def self.len129To240_64B(input : Bytes, secret : PUInt8, seed : UInt64) : UInt64
      inputPtr = input.to_unsafe
      len = input.size
      acc : UInt64 = len.to_u64! &* XXH_PRIME64_1
      numRounds : Int32 = len.tdiv(16)

      # First 128 bytes: eight rounds against the start of the secret.
      8.times do |i|
        acc = acc &+ mix16B(inputPtr + (16 &* i), secret + (16 &* i), seed)
      end
      acc = avalanche(acc)

      # Remaining whole 16-byte chunks, keyed from a shifted secret offset.
      8.upto(numRounds &- 1) do |i|
        acc = acc &+ mix16B(inputPtr + (16 &* i), secret + (16 &* (i &- 8)) + XXH3_MIDSIZE_STARTOFFSET, seed)
      end

      # Last 16 bytes of the input (may overlap the previous chunk).
      acc = acc &+ mix16B(inputPtr + (len &- 16), secret + (XXH3_SECRET_SIZE_MIN &- XXH3_MIDSIZE_LASTOFFSET), seed)
      avalanche(acc)
    end

    # Final mixing step for an XXH3 accumulator.
    #
    # Applies an xor-shift by 37, a wrapping multiplication by a fixed odd
    # constant, and an xor-shift by 32, so that every input bit can influence
    # every bit of the returned value.
    @[AlwaysInline]
    private def self.avalanche(h64 : UInt64) : UInt64
      mixed = xorShift64(h64, 37) &* 0x165667919E3779F9_u64
      xorShift64(mixed, 32)
    end

    # Final mixing step in the XXH64 style: three xor-shifts (33, 29, 32)
    # interleaved with wrapping multiplications by `XXH_PRIME64_2` and
    # `XXH_PRIME64_3`.
    @[AlwaysInline]
    private def self.avalancheXXH64(hash : UInt64) : UInt64
      stage1 = (hash ^ hash.unsafe_shr(33)) &* XXH_PRIME64_2
      stage2 = (stage1 ^ stage1.unsafe_shr(29)) &* XXH_PRIME64_3
      stage2 ^ stage2.unsafe_shr(32)
    end

    # Finalizer used by `len4To8_64B`: rotate-rotate-multiply-xorshift-multiply-
    # xorshift, with *len* folded in between the two multiplications.
    @[AlwaysInline]
    private def self.rrmxmx(h64 : UInt64, len : Int) : UInt64
      multiplier = 0x9FB21C651E98DF25_u64

      rotated = h64 ^ (rotl64(h64, 49) ^ rotl64(h64, 24))
      scaled = rotated &* multiplier
      # The length is added to the shifted value before it is xor-ed back in.
      folded = scaled ^ (scaled.unsafe_shr(35) &+ len)
      xorShift64(folded &* multiplier, 28)
    end

    # Hashes a long *input* with *seed*.
    #
    # A zero seed hashes directly with the supplied *secret*.  Any other seed
    # ignores *secret* and instead hashes with a freshly allocated buffer of
    # `XXH3_SECRET_DEFAULT_SIZE` bytes filled in by `initCustomSecret`.
    @[AlwaysInline]
    private def self.hashLong64BWithSeed(input : Bytes, seed : UInt64, secret : Bytes) : UInt64
      return hashLong64BInternal(input, secret) if seed == 0

      seededSecret : Bytes = Bytes.new(XXH3_SECRET_DEFAULT_SIZE)
      initCustomSecret(seededSecret.to_unsafe, seed)
      hashLong64BInternal(input, seededSecret)
    end

    # Fills the buffer at *customSecret* with a secret derived from *seed*.
    #
    # This only forwards to `initCustomSecretScalar`; it returns nothing.
    @[AlwaysInline]
    private def self.initCustomSecret(customSecret : PUInt8, seed : UInt64) : Nil
      initCustomSecretScalar(customSecret, seed)
    end

    # Derives a seeded secret from `XXH3_SECRET` and writes it to
    # *customSecret*.
    #
    # The source secret is processed in 16-byte steps: *seed* is added
    # (wrapping) to the first 64-bit word of each step and subtracted
    # (wrapping) from the second.  `XXH_SECRET_DEFAULT_SIZE.tdiv(16)` steps are
    # written, so *customSecret* must have room for that many 16-byte groups.
    @[AlwaysInline]
    private def self.initCustomSecretScalar(customSecret : PUInt8, seed : UInt64) : Nil
      base : PUInt8 = XXH3_SECRET.to_unsafe
      rounds : Int32 = XXH_SECRET_DEFAULT_SIZE.tdiv(16)

      rounds.times do |round|
        offset = 16 &* round
        # Both words are read before either is written.
        plusSeed : UInt64 = readLE64(base + offset) &+ seed
        minusSeed : UInt64 = readLE64(base + offset + 8) &- seed
        writeLE64(customSecret + offset, plusSeed)
        writeLE64(customSecret + offset + 8, minusSeed)
      end
    end

    # Runs the long-input algorithm over *input* with *secret*.
    #
    # Eight accumulators are initialised from `XXH3_INIT_ACC`, updated by
    # `hashLongInternalLoop`, and then merged with the secret bytes starting
    # at `XXH_SECRET_MERGEACCS_START`.
    @[AlwaysInline]
    private def self.hashLong64BInternal(input : Bytes, secret : Bytes) : UInt64
      initial = XXH3_INIT_ACC.to_unsafe
      lanes = StaticArray(UInt64, 8).new { |idx| initial[idx].to_u64! }

      hashLongInternalLoop(lanes.to_unsafe, input, secret)

      mergeSecret = secret.to_unsafe + XXH_SECRET_MERGEACCS_START
      mergeStart = input.size.to_u64! &* XXH_PRIME64_1
      mergeAccs(lanes.to_unsafe, mergeSecret, mergeStart)
    end

    # Folds the accumulators at *acc* into a single 64-bit value.
    #
    # Beginning from *start*, four accumulator pairs are each combined with 16
    # bytes of *secret* via `mix2Accs` and summed (wrapping); the total is
    # then avalanched.
    @[AlwaysInline]
    private def self.mergeAccs(acc : PUInt64, secret : PUInt8, start : UInt64) : UInt64
      total : UInt64 = start
      4.times do |pair|
        total &+= mix2Accs(acc + (2 &* pair), secret + (16 &* pair))
      end
      avalanche(total)
    end

    # Combines two adjacent accumulators with two 64-bit secret words: each
    # accumulator is xor-ed with its secret word and the pair is passed to
    # `mult128Fold64`.
    @[AlwaysInline]
    private def self.mix2Accs(acc : PUInt64, secret : PUInt8) : UInt64
      keyedFirst = acc[0] ^ readLE64(secret)
      keyedSecond = acc[1] ^ readLE64(secret + 8)
      mult128Fold64(keyedFirst, keyedSecond)
    end

    # Feeds all of *input* into the accumulators at *acc*.
    #
    # The input is consumed in blocks whose size follows from the secret
    # length; the accumulators are scrambled after every full block.  The
    # stripes of the trailing partial block are accumulated next, and the last
    # `XXH_STRIPE_LEN` bytes of *input* are always accumulated once more with
    # a dedicated secret offset.
    @[AlwaysInline]
    private def self.hashLongInternalLoop(acc : PUInt64, input : Bytes, secret : Bytes) : Nil
      data : PUInt8 = input.to_unsafe
      key : PUInt8 = secret.to_unsafe
      dataLen : Int32 = input.size
      keyLen : Int32 = secret.size

      stripesPerBlock : Int32 = (keyLen &- XXH_STRIPE_LEN).tdiv(XXH_SECRET_CONSUME_RATE)
      bytesPerBlock : Int32 = XXH_STRIPE_LEN * stripesPerBlock
      fullBlocks : Int32 = (dataLen &- 1).tdiv(bytesPerBlock)

      # The scramble key is always the last stripe-sized window of the secret.
      scrambleKey : PUInt8 = key + (keyLen &- XXH_STRIPE_LEN)

      fullBlocks.times do |block|
        accumulate(acc, data + (block &* bytesPerBlock), key, stripesPerBlock)
        scrambleAcc(acc, scrambleKey)
      end

      # Whole stripes remaining after the last full block.
      consumed : Int32 = bytesPerBlock &* fullBlocks
      tailStripes : Int32 = ((dataLen &- 1) &- consumed).tdiv(XXH_STRIPE_LEN)
      accumulate(acc, data + (fullBlocks &* bytesPerBlock), key, tailStripes)

      # Final stripe: the last XXH_STRIPE_LEN bytes of the input.
      lastStripe : PUInt8 = data + (dataLen &- XXH_STRIPE_LEN)
      lastKey : PUInt8 = key + (keyLen &- XXH_STRIPE_LEN &- XXH_SECRET_LASTACC_START)
      accumulate512(acc, lastStripe, lastKey)
    end

    # Accumulates *numStripes* consecutive stripes of *input* into *acc*.
    #
    # For each stripe the input advances by `XXH_STRIPE_LEN` bytes and the
    # secret by `XXH_SECRET_CONSUME_RATE` bytes.
    @[AlwaysInline]
    private def self.accumulate(acc : PUInt64, input : PUInt8, secret : PUInt8, numStripes : Int) : Nil
      numStripes.times do |stripe|
        stripeData : PUInt8 = input + (stripe &* XXH_STRIPE_LEN)
        stripeKey = secret + (stripe &* XXH_SECRET_CONSUME_RATE)
        accumulate512(acc, stripeData, stripeKey)
      end
    end

    # Accumulates one stripe of *input* into *acc* using *secret*.  This
    # implementation always forwards to `accumulate512Scalar`.
    #
    # `accumulate512` is the tightest loop for long inputs, and it is the most
    # optimized.
    #
    # It is a hardened version of UMAC, based off of FARSH's implementation.
    #
    # This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD
    # implementations, and it is ridiculously fast.
    #
    # We harden it by mixing the original input to the accumulators as well as
    # the product.
    #
    # This means that in the (relatively likely) case of a multiply by zero, the
    # original input is preserved.
    #
    # On 128-bit inputs, we swap 64-bit pairs when we add the input to improve
    # cross-pollination, as otherwise the upper and lower halves would be
    # essentially independent.
    #
    # This doesn't matter on 64-bit hashes since they all get merged together in
    # the end, so we skip the extra step.
    #
    # Both XXH3 64-bit and XXH3 128-bit use this subroutine.
    @[AlwaysInline]
    private def self.accumulate512(acc : PUInt64, input : PUInt8, secret : PUInt8) : Nil
      accumulate512Scalar(acc, input, secret)
    end

    # Scalar implementation of `accumulate512`: runs `scalarRound` once for
    # each of the `XXH_ACC_NB` accumulator lanes.
    @[AlwaysInline]
    private def self.accumulate512Scalar(acc : PUInt64, input : PUInt8, secret : PUInt8) : Nil
      XXH_ACC_NB.times do |lane|
        scalarRound(acc, input, secret, lane)
      end
    end

    # One lane of `accumulate512Scalar`.
    #
    # Reads the 64-bit word of *input* and of *secret* belonging to *lane*.
    # The raw input word is added to the neighbouring accumulator
    # (`lane ^ 1`), and a product of the keyed word with its own upper half
    # is added to the accumulator of *lane* itself.
    @[AlwaysInline]
    private def self.scalarRound(acc : PUInt64, input : PUInt8, secret : PUInt8, lane : Int) : Nil
      byteOffset = lane &* 8
      word : UInt64 = readLE64(input + byteOffset)
      keyed : UInt64 = word ^ readLE64(secret + byteOffset)

      neighbour = lane ^ 1
      acc[neighbour] = acc[neighbour] &+ word
      # NOTE(review): the full 64-bit keyed word is passed as the first
      # operand; presumably `mult32To64` truncates to 32 bits - confirm.
      acc[lane] = acc[lane] &+ mult32To64(keyed, keyed.unsafe_shr(32))
    end

    # Scrambles the accumulators at *acc* with *secret*.
    #
    # `hashLongInternalLoop` calls this after each full block of input has
    # been accumulated.  This implementation always forwards to
    # `scrambleAccScalar`.
    @[AlwaysInline]
    private def self.scrambleAcc(acc : PUInt64, secret : PUInt8) : Nil
      scrambleAccScalar(acc, secret)
    end

    # Scalar implementation of `scrambleAcc`: runs `scalarScrambleRound` once
    # for each of the `XXH_ACC_NB` accumulator lanes.
    @[AlwaysInline]
    private def self.scrambleAccScalar(acc : PUInt64, secret : PUInt8) : Nil
      XXH_ACC_NB.times do |lane|
        scalarScrambleRound(acc, secret, lane)
      end
    end

    # One lane of `scrambleAccScalar`.
    #
    # The accumulator of *lane* is xor-shifted by 47, xor-ed with the 64-bit
    # secret word of that lane, multiplied (wrapping) by `XXH_PRIME32_1`, and
    # stored back.
    @[AlwaysInline]
    private def self.scalarScrambleRound(acc : PUInt64, secret : PUInt8, lane : Int) : Nil
      laneKey : UInt64 = readLE64(secret + (lane &* 8))
      scrambled : UInt64 = xorShift64(acc[lane], 47) ^ laneKey
      acc[lane] = scrambled &* XXH_PRIME32_1
    end
  end
end