Login
Artifact [a778705df5]
Login

Artifact a778705df5ad2b7dc0f2a0b7194f6f17d51a9745cfa0cb32b900dd1fad6f6d14:


#### MIT License
####
#### Copyright (c) 2023-2024 Remilia Scarlet
#### Copyright (c) 2018 Melnik Alexander
####
#### Permission is hereby granted, free of charge, to any person obtaining a
#### copy of this software and associated documentation files (the "Software"),
#### to deal in the Software without restriction, including without limitation
#### the rights to use, copy, modify, merge, publish, distribute, sublicense,
#### and/or sell copies of the Software, and to permit persons to whom the
#### Software is furnished to do so, subject to the following conditions:
####
#### The above copyright notice and this permission notice shall be included in
#### all copies or substantial portions of the Software.
####
#### THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#### IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#### FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#### AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#### LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
#### FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
#### DEALINGS IN THE SOFTWARE.
require "./xxhash-common"

####
#### Ported from C#:
#### https://github.com/uranium62/xxHash/tree/6b20e7f7b32dfc29e5019d3d35f5b7270f1656f3
####
#### Remi: TODO: The original code was written before a native UInt128 type was
#### added to .NET, so it used its own internal `uint128` type.  Some of that
#### design was retained in the code below.  The messy low/high UInt64 code
#### should ideally be converted to use the native UInt128 type.
####

module RemiLib::Digest
  # Scalar implementation of the 128-bit XXH3 hash.
  #
  # This class only holds the private building blocks of the 128-bit variant.
  # The constants it uses (`XXH_PRIME64_1`, `XXH3_SECRET`, `XXH3_MIDSIZE_MAX`,
  # `XXH_STRIPE_LEN`, ...) and the low-level helpers (`readLE64`, `mix16B`,
  # `makeU128`, `low64`, `high64`, ...) are not defined here; presumably they
  # come from `XXHashInternal` in `xxhash-common` -- verify there.
  #
  # 128-bit intermediate values are handled as a low/high pair of `UInt64`
  # halves that are packed into a `UInt128` with `makeU128` and unpacked with
  # `low64`/`high64`.
  class XXHash128 < XXHashInternal
    # Size in bytes of the long-input accumulator array (8 `UInt64` lanes).
    private XXH3_ACC_SIZE = 64

    # Computes the 128-bit XXH3 hash of *input* using *seed* and *secret*,
    # dispatching on the input length:
    #
    # * 0..16 bytes: `len0To16`
    # * 17..128 bytes: `len17To128`
    # * 129..`XXH3_MIDSIZE_MAX` bytes: `len129To240`
    # * anything longer: `hashLongWithSeed`
    #
    # Every path reads *secret* through raw pointers at fixed offsets (the
    # mid-size path even computes one offset relative to
    # `XXH3_SECRET_SIZE_MIN`), so a secret shorter than `XXH3_SECRET_SIZE_MIN`
    # raises `ArgumentError` instead of reading past the end of the buffer.
    @[AlwaysInline]
    private def self.xxh3Internal128Bit(input : Bytes, seed : UInt64, secret : Bytes) : UInt128
      if secret.size < XXH3_SECRET_SIZE_MIN
        raise ArgumentError.new("XXH3 secret must be at least #{XXH3_SECRET_SIZE_MIN} bytes long")
      end

      len = input.size
      case
      when len <= 16 then len0To16(input, secret, seed)
      when len <= 128 then len17To128(input, secret, seed)
      when len <= XXH3_MIDSIZE_MAX then len129To240(input, secret, seed)
      else hashLongWithSeed(input, secret, seed)
      end
    end

    # Mixes all bits to finalize *hash*.
    #
    # The final mix ensures that all input bits have a chance to impact any bit
    # in the output digest, resulting in an unbiased distribution.  This is a
    # sequence of xorshift/multiply steps (shifts 33, 29, 32; multipliers
    # `XXH_PRIME64_2` and `XXH_PRIME64_3`).
    @[AlwaysInline]
    private def self.xxh64Avalanche(hash : UInt64) : UInt64
      hash ^= hash.unsafe_shr(33)
      hash = hash &* XXH_PRIME64_2
      hash ^= hash.unsafe_shr(29)
      hash = hash &* XXH_PRIME64_3
      hash ^ hash.unsafe_shr(32)
    end

    # Lighter finalizer used by most XXH3 paths: xorshift by 37, multiply by
    # `0x165667919E3779F9`, then xorshift by 32.  Like `xxh64Avalanche`, it
    # spreads every bit of *h64* across the whole output.
    @[AlwaysInline]
    private def self.xxh3Avalanche(h64 : UInt64) : UInt64
      h64 = xorShift64(h64, 37)
      h64 = h64 &* 0x165667919E3779F9_u64
      xorShift64(h64, 32)
    end

    # Hashes inputs of 0 to 16 bytes by delegating to the 9..16, 4..8 or 1..3
    # byte routines.
    #
    # An empty input has no data to mix, so each half of the result is the
    # avalanche of *seed* XOR-ed with a value derived from secret bytes 64..95.
    @[AlwaysInline]
    private def self.len0To16(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len = input.size
      case
      when len > 8
        len9To16(input, secret, seed)
      when len >= 4
        len4To8(input, secret, seed)
      when len != 0
        len1To3(input, secret, seed)
      else
        secretPtr : PUInt8 = secret.to_unsafe
        bitflipL : UInt64 = readLE64(secretPtr + 64) ^ readLE64(secretPtr + 72)
        bitflipH : UInt64 = readLE64(secretPtr + 80) ^ readLE64(secretPtr + 88)
        # High half in the upper 64 bits, low half in the lower 64 bits.
        (xxh64Avalanche(seed ^ bitflipH).to_u128! << 64) | xxh64Avalanche(seed ^ bitflipL)
      end
    end

    # Hashes inputs of 17 to 128 bytes.
    #
    # The accumulator starts as `len * XXH_PRIME64_1` (high half zero).
    # Depending on the length, up to four `mix32B` rounds then combine a
    # 16-byte chunk from the front of the input with a 16-byte chunk from the
    # back, each round using the next 32 bytes of *secret*.
    @[AlwaysInline]
    private def self.len17To128(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len : Int32 = input.size
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe
      acc : UInt128 = (len.to_u64! &* XXH_PRIME64_1).to_u128!

      # Larger inputs add outer rounds first; the innermost round always runs.
      if len > 32
        if len > 64
          if len > 96
            acc = mix32B(acc, inputPtr + 48, inputPtr + (len &- 64), secretPtr + 96, seed)
          end
          acc = mix32B(acc, inputPtr + 32, inputPtr + (len &- 48), secretPtr + 64, seed)
        end
        acc = mix32B(acc, inputPtr + 16, inputPtr + (len &- 32), secretPtr + 32, seed)
      end
      acc = mix32B(acc, inputPtr, inputPtr + (len &- 16), secretPtr, seed)

      hashL : UInt64 = low64(acc) &+ high64(acc)
      hashH : UInt64 = (low64(acc) &* XXH_PRIME64_1) &+
                       (high64(acc) &* XXH_PRIME64_4) &+
                       ((len.to_u64! &- seed) &* XXH_PRIME64_2)

      hashL = xxh3Avalanche(hashL)
      hashH = 0u64 &- xxh3Avalanche(hashH) # Wrapping negation of the avalanche.
      makeU128(hashL, hashH)
    end

    # Hashes inputs of 9 to 16 bytes.
    #
    # The first and last 8 bytes of *input* are read (they overlap when the
    # input is shorter than 16 bytes), keyed with secret bytes 32..63 and the
    # seed, then combined through two 64x64->128-bit multiplications.
    @[AlwaysInline]
    private def self.len9To16(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len : Int32 = input.size
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe

      bitflipL : UInt64 = (readLE64(secretPtr + 32) ^ readLE64(secretPtr + 40)) &- seed
      bitflipH : UInt64 = (readLE64(secretPtr + 48) ^ readLE64(secretPtr + 56)) &+ seed
      inputL : UInt64 = readLE64(inputPtr)
      inputH : UInt64 = readLE64(inputPtr + (len &- 8))

      m128 : UInt128 = mult64To128(inputL ^ inputH ^ bitflipL, XXH_PRIME64_1)
      m128L : UInt64 = low64(m128)
      m128H : UInt64 = high64(m128)

      # Fold the length into the top bits of the low half.
      m128L = m128L &+ (len.to_u64! &- 1).unsafe_shl(54)
      inputH ^= bitflipH

      # NOTE(review): presumably `mult32To64` only uses the low 32 bits of
      # each operand -- confirm in xxhash-common.
      m128H = m128H &+ inputH &+ mult32To64(inputH.to_u64!, XXH_PRIME32_2.to_u64! &- 1)
      m128L ^= swap64(m128H)

      hash : UInt128 = mult64To128(m128L, XXH_PRIME64_2)
      hashL : UInt64 = low64(hash)
      hashH : UInt64 = high64(hash)

      hashH = hashH &+ (m128H &* XXH_PRIME64_2)
      hashL = xxh3Avalanche(hashL)
      hashH = xxh3Avalanche(hashH)
      makeU128(hashL, hashH)
    end

    # Hashes inputs of 1 to 3 bytes.
    #
    # The first, middle and last bytes plus the length are packed into one
    # 32-bit word (duplicating bytes for short inputs).  A rotated byte-swapped
    # copy of that word feeds the high half.  Each half is keyed with secret
    # bytes 0..15 and the seed, then finalized with `xxh64Avalanche`.
    @[AlwaysInline]
    private def self.len1To3(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len : Int32 = input.size
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe

      c1 : UInt8 = inputPtr[0]
      c2 : UInt8 = inputPtr[len.unsafe_shr(1)]
      c3 : UInt8 = inputPtr[len &- 1]

      combinedL : UInt32 = c1.to_u32!.unsafe_shl(16) |
                           c2.to_u32!.unsafe_shl(24) |
                           c3.to_u32! |
                           len.to_u32!.unsafe_shl(8)
      combinedH : UInt32 = rotl32(swap32(combinedL), 13)

      bitflipL : UInt64 = (readLE32(secretPtr) ^ readLE32(secretPtr + 4)).to_u64! &+ seed
      bitflipH : UInt64 = (readLE32(secretPtr + 8) ^ readLE32(secretPtr + 12)).to_u64! &- seed
      keyedL : UInt64 = combinedL.to_u64! ^ bitflipL
      keyedH : UInt64 = combinedH.to_u64! ^ bitflipH

      hashL : UInt64 = xxh64Avalanche(keyedL)
      hashH : UInt64 = xxh64Avalanche(keyedH)
      makeU128(hashL, hashH)
    end

    # Hashes inputs of 4 to 8 bytes.
    #
    # The first and last 4 bytes (overlapping when the input is shorter than
    # 8 bytes) form one 64-bit word.  That word is keyed with secret bytes
    # 16..31 and a seed variant, then multiplied to 128 bits.
    @[AlwaysInline]
    private def self.len4To8(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len : Int32 = input.size
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe

      # Put a byte-swapped copy of the seed's low 32 bits into its high half.
      seed ^= swap32(seed.to_u32!).to_u64!.unsafe_shl(32)

      inputL : UInt32 = readLE32(inputPtr)
      inputH : UInt32 = readLE32(inputPtr + (len &- 4))
      input64 : UInt64 = inputL.to_u64! &+ inputH.to_u64!.unsafe_shl(32)
      bitflip : UInt64 = (readLE64(secretPtr + 16) ^ readLE64(secretPtr + 24)) &+ seed
      keyed : UInt64 = input64 ^ bitflip

      # `len << 2` keeps the added term a multiple of 4.
      m128 : UInt128 = mult64To128(keyed, XXH_PRIME64_1 &+ len.to_u64!.unsafe_shl(2))
      m128L : UInt64 = low64(m128)
      m128H : UInt64 = high64(m128)

      m128H = m128H &+ m128L.unsafe_shl(1)
      m128L ^= m128H.unsafe_shr(3)

      m128L = xorShift64(m128L, 35)
      m128L = m128L &* 0x9FB21C651E98DF25_u64
      m128L = xorShift64(m128L, 28)
      m128H = xxh3Avalanche(m128H)
      makeU128(m128L, m128H)
    end

    # Hashes inputs of 129 to `XXH3_MIDSIZE_MAX` bytes.
    #
    # The first four 32-byte chunks are mixed with secret offsets 0, 32, 64
    # and 96, and the accumulator is then avalanched.  Any further whole
    # 32-byte chunks use secret offsets starting at `XXH3_MIDSIZE_STARTOFFSET`.
    # Finally, the last 32 bytes of the input are mixed with the negated seed.
    @[AlwaysInline]
    private def self.len129To240(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      len : Int32 = input.size
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe

      numRounds : Int32 = len.tdiv(32)
      accL : UInt64 = len.to_u64! &* XXH_PRIME64_1
      accH : UInt64 = 0
      acc : UInt128 = makeU128(accL, accH)
      4.times do |i|
        acc = mix32B(acc,
                     inputPtr + (32 &* i),
                     inputPtr + (32 &* i) + 16,
                     secretPtr + (32 &* i),
                     seed)
      end

      accL = xxh3Avalanche(low64(acc))
      accH = xxh3Avalanche(high64(acc))
      acc = makeU128(accL, accH)

      # len > 128 here, so numRounds >= 4 and this range is well formed.
      4.upto(numRounds &- 1) do |i|
        acc = mix32B(acc,
                     inputPtr + (32 &* i),
                     inputPtr + (32 &* i) + 16,
                     secretPtr + (XXH3_MIDSIZE_STARTOFFSET &+ (32 &* (i &- 4))),
                     seed)
      end

      # Last 32 bytes; note the two 16-byte halves are passed in swapped order.
      acc = mix32B(acc,
                   inputPtr + (len &- 16),
                   inputPtr + (len &- 32),
                   secretPtr + (XXH3_SECRET_SIZE_MIN &- XXH3_MIDSIZE_LASTOFFSET &- 16),
                   0u64 &- seed)

      accL = low64(acc)
      accH = high64(acc)
      hashL : UInt64 = accL &+ accH
      hashH : UInt64 = (accL &* XXH_PRIME64_1) &+
                       (accH &* XXH_PRIME64_4) &+
                       ((len.to_u64! &- seed) &* XXH_PRIME64_2)
      hashL = xxh3Avalanche(hashL)
      hashH = 0u64 &- xxh3Avalanche(hashH) # Wrapping negation of the avalanche.
      makeU128(hashL, hashH)
    end

    # A bit slower than XXH3_mix16B, but handles multiply by zero better.
    #
    # The low half absorbs `mix16B` of *input1* (keyed with *secret*) and the
    # sum of the two 64-bit words at *input2*.  The high half absorbs `mix16B`
    # of *input2* (keyed with `secret + 16`) and the sum of the two 64-bit
    # words at *input1*.
    @[AlwaysInline]
    private def self.mix32B(acc : UInt128, input1 : PUInt8, input2 : PUInt8, secret : PUInt8, seed : UInt64) : UInt128
      accL : UInt64 = low64(acc) &+ mix16B(input1, secret, seed)
      accL ^= readLE64(input2) &+ readLE64(input2 + 8)

      accH : UInt64 = high64(acc) &+ mix16B(input2, secret + 16, seed)
      accH ^= readLE64(input1) &+ readLE64(input1 + 8)
      makeU128(accL, accH)
    end

    # Hashes inputs longer than `XXH3_MIDSIZE_MAX` bytes.
    #
    # With a zero *seed*, *secret* is used directly.  Otherwise a custom secret
    # is derived from the default `XXH3_SECRET` and *seed* (see
    # `xxh3InitCustomSecretScalar`), and *secret* is not used.
    @[AlwaysInline]
    private def self.hashLongWithSeed(input : Bytes, secret : Bytes, seed : UInt64) : UInt128
      if seed == 0
        hashLongInternal(input, secret)
      else
        # Size the buffer with the same constant that `xxh3InitCustomSecretScalar`
        # uses to decide how many bytes to write, so the two cannot diverge.
        customSecret = Slice(UInt8).new(XXH_SECRET_DEFAULT_SIZE)
        xxh3InitCustomSecret(customSecret.to_unsafe, seed)
        hashLongInternal(input, customSecret)
      end
    end

    # Long-input core: accumulates *input* into eight `UInt64` lanes
    # initialized from `XXH3_INIT_ACC`.  It then merges the lanes twice, at
    # different secret offsets and with different start values, to produce
    # the low and high halves of the hash.
    @[AlwaysInline]
    private def self.hashLongInternal(input : Bytes, secret : Bytes) : UInt128
      len = input.size.to_u64!
      src = XXH3_INIT_ACC.to_unsafe
      acc = StaticArray(UInt64, 8).new { |i| src[i].to_u64! }
      hashLongInternalLoop(acc.to_unsafe, input, secret)

      secretPtr : PUInt8 = secret.to_unsafe
      accPtr : PUInt64 = acc.to_unsafe
      hashL : UInt64 = mergeAccs(accPtr,
                                 secretPtr + XXH_SECRET_MERGEACCS_START,
                                 len &* XXH_PRIME64_1)
      hashH : UInt64 = mergeAccs(accPtr,
                                 secretPtr + (secret.size.to_u64! &- XXH3_ACC_SIZE &- XXH_SECRET_MERGEACCS_START),
                                 ~(len &* XXH_PRIME64_2))
      makeU128(hashL, hashH)
    end

    # Runs the long-input accumulation loop over *input*, updating *acc* in
    # place.
    #
    # The input is processed in blocks of `numStripesPerBlock` stripes of
    # `XXH_STRIPE_LEN` bytes each, and the accumulators are scrambled after
    # every full block.  The remaining whole stripes are then accumulated, and
    # finally the last `XXH_STRIPE_LEN` bytes of the input (which may overlap
    # already-processed data) are accumulated once more.
    @[AlwaysInline]
    private def self.hashLongInternalLoop(acc : PUInt64, input : Bytes, secret : Bytes) : Nil
      inputPtr : PUInt8 = input.to_unsafe
      secretPtr : PUInt8 = secret.to_unsafe
      len : Int32 = input.size
      secretLen : Int32 = secret.size

      numStripesPerBlock : Int32 = (secretLen &- XXH_STRIPE_LEN).tdiv(XXH_SECRET_CONSUME_RATE)
      blockLen : Int32 = XXH_STRIPE_LEN &* numStripesPerBlock
      # Dividing `len - 1` ensures numBlocks * blockLen < len, so some data is
      # always left for the partial-block and last-stripe steps below.
      numBlocks : Int32 = (len &- 1).tdiv(blockLen)

      numBlocks.times do |n|
        xxh3Accumulate(acc, inputPtr + n &* blockLen, secretPtr, numStripesPerBlock)
        xxh3ScrambleAcc(acc, secretPtr + (secretLen &- XXH_STRIPE_LEN))
      end

      numStripes : Int32 = ((len &- 1) &- (blockLen &* numBlocks)).tdiv(XXH_STRIPE_LEN)
      xxh3Accumulate(acc, inputPtr + numBlocks &* blockLen, secretPtr, numStripes)

      lastStripe : PUInt8 = inputPtr + (len &- XXH_STRIPE_LEN)
      xxh3Accumulate512(acc, lastStripe, secretPtr + (secretLen &- XXH_STRIPE_LEN &- XXH_SECRET_LASTACC_START))
    end

    # Folds the eight accumulator lanes into one 64-bit value.  The lanes are
    # combined as four `mix2Accs` pairs, each keyed with the next 16 bytes of
    # *secret* and added onto *start*, and the sum is finalized with
    # `xxh3Avalanche`.
    @[AlwaysInline]
    private def self.mergeAccs(acc : PUInt64, secret : PUInt8, start : UInt64) : UInt64
      result : UInt64 = start
      4.times do |i|
        result = result &+ mix2Accs(acc + 2 &* i, secret + 16 &* i)
      end
      xxh3Avalanche(result)
    end

    # Combines two adjacent accumulator lanes, each XOR-ed with 8 bytes of
    # *secret*, via `mult128Fold64` (presumably a 64x64->128 multiply folded
    # back to 64 bits -- confirm in xxhash-common).
    @[AlwaysInline]
    private def self.mix2Accs(acc : PUInt64, secret : PUInt8) : UInt64
      mult128Fold64(acc[0] ^ readLE64(secret),
                    acc[1] ^ readLE64(secret + 8))
    end

    # Accumulates *numStripes* consecutive stripes.  Stripe `n` reads input at
    # `n * XXH_STRIPE_LEN` and secret at `n * XXH_SECRET_CONSUME_RATE`, so the
    # secret advances more slowly than the input.
    @[AlwaysInline]
    private def self.xxh3Accumulate(acc : PUInt64, input : PUInt8, secret : PUInt8, numStripes : Int) : Nil
      numStripes.times do |n|
        inp : PUInt8 = input + n &* XXH_STRIPE_LEN
        xxh3Accumulate512(acc, inp, secret + n &* XXH_SECRET_CONSUME_RATE)
      end
    end

    # xxh3Accumulate512 is the tightest loop for long inputs, and it is the most
    # optimized.
    #
    # It is a hardened version of UMAC, based off of FARSH's implementation.
    #
    # This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD
    # implementations, and it is ridiculously fast.
    #
    # We harden it by mixing the original input to the accumulators as well as
    # the product.
    #
    # This means that in the (relatively likely) case of a multiply by zero, the
    # original input is preserved.
    #
    # On 128-bit inputs, we swap 64-bit pairs when we add the input to improve
    # cross-pollination, as otherwise the upper and lower halves would be
    # essentially independent.
    #
    # This doesn't matter on 64-bit hashes since they all get merged together in
    # the end, so we skip the extra step.
    #
    # Both XXH3 64-bit and XXH3 128-bit use this subroutine.
    @[AlwaysInline]
    private def self.xxh3Accumulate512(acc : PUInt64, input : PUInt8, secret : PUInt8) : Nil
      xxh3Accumulate512Scalar(acc, input, secret)
    end

    # Processes one stripe (`XXH_ACC_NB` lanes of 8 bytes each) using the
    # scalar path.
    @[AlwaysInline]
    private def self.xxh3Accumulate512Scalar(acc : PUInt64, input : PUInt8, secret : PUInt8) : Nil
      XXH_ACC_NB.times { |i| xxh3ScalarRound(acc, input, secret, i) }
    end

    # Scalar round for `#xxh3Accumulate512Scalar`.
    #
    # Adds the raw input word to the neighbouring lane (`lane ^ 1`).  Adds the
    # product of the low and high 32-bit halves of `input ^ secret` to the
    # lane itself.
    @[AlwaysInline]
    private def self.xxh3ScalarRound(acc : PUInt64, input : PUInt8, secret : PUInt8, lane : Int) : Nil
      dataVal : UInt64 = readLE64(input + lane &* 8)
      dataKey : UInt64 = dataVal ^ readLE64(secret + lane &* 8)
      acc[lane ^ 1] = acc[lane ^ 1] &+ dataVal
      acc[lane] = acc[lane] &+ mult32To64(dataKey, dataKey.unsafe_shr(32))
    end

    # Scrambles the accumulators after a large chunk has been read.
    @[AlwaysInline]
    private def self.xxh3ScrambleAcc(acc : PUInt64, secret : PUInt8) : Nil
      xxh3ScrambleAccScalar(acc, secret)
    end

    # Scalar implementation of `xxh3ScrambleAcc`: scrambles each of the
    # `XXH_ACC_NB` lanes in turn.
    @[AlwaysInline]
    private def self.xxh3ScrambleAccScalar(acc : PUInt64, secret : PUInt8) : Nil
      XXH_ACC_NB.times { |i| xxh3ScalarScrambleRound(acc, secret, i) }
    end

    # Scalar scramble step for `xxh3ScrambleAccScalar`.  It applies xorshift 47
    # to one lane, XORs in 8 bytes of *secret*, and multiplies by
    # `XXH_PRIME32_1` (wrapping).
    @[AlwaysInline]
    private def self.xxh3ScalarScrambleRound(acc : PUInt64, secret : PUInt8, lane : Int) : Nil
      key64 : UInt64 = readLE64(secret + lane &* 8)
      acc64 : UInt64 = xorShift64(acc[lane], 47)
      acc64 ^= key64
      acc[lane] = acc64 &* XXH_PRIME32_1
    end

    # Fills *customSecret* with a secret derived from *seed*; see
    # `xxh3InitCustomSecretScalar`.
    @[AlwaysInline]
    private def self.xxh3InitCustomSecret(customSecret : PUInt8, seed : UInt64) : Nil
      xxh3InitCustomSecretScalar(customSecret, seed)
    end

    # Writes `XXH_SECRET_DEFAULT_SIZE` bytes to *customSecret*.  The bytes are
    # derived from the default `XXH3_SECRET`: for every 16-byte row, *seed* is
    # added to the first 64-bit word and subtracted from the second (both
    # wrapping).  *customSecret* must have room for `XXH_SECRET_DEFAULT_SIZE`
    # bytes.
    @[AlwaysInline]
    private def self.xxh3InitCustomSecretScalar(customSecret : PUInt8, seed : UInt64) : Nil
      secretPtr : PUInt8 = XXH3_SECRET.to_unsafe
      numRounds : Int32 = XXH_SECRET_DEFAULT_SIZE.tdiv(16)
      numRounds.times do |i|
        offset = 16 &* i
        lo = readLE64(secretPtr + offset) &+ seed
        hi = readLE64(secretPtr + offset + 8) &- seed
        writeLE64(customSecret + offset, lo)
        writeLE64(customSecret + offset + 8, hi)
      end
    end
  end
end