001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.codec.digest;
018
019import java.util.Arrays;
020import java.util.Objects;
021
022/**
023 * Implements the Blake3 algorithm providing a {@linkplain #initHash() hash function} with extensible output (XOF), a
024 * {@linkplain #initKeyedHash(byte[]) keyed hash function} (MAC, PRF), and a
025 * {@linkplain #initKeyDerivationFunction(byte[]) key derivation function} (KDF). Blake3 has a 128-bit security level
 * and a default output length of 256 bits (32 bytes) which can be extended up to 2<sup>64</sup> bytes.
027 * <h2>Hashing</h2>
 * <p>Hash mode calculates the same output hash given the same input bytes and can be used as both a message digest and
 * an extensible output function.</p>
030 * <pre>{@code
031 *      Blake3 hasher = Blake3.initHash();
032 *      hasher.update("Hello, world!".getBytes(StandardCharsets.UTF_8));
033 *      byte[] hash = new byte[32];
034 *      hasher.doFinalize(hash);
035 * }</pre>
036 * <h2>Keyed Hashing</h2>
 * <p>Keyed hashes take a 32-byte secret key and calculate a message authentication code on some input bytes. These
038 * also work as pseudo-random functions (PRFs) with extensible output similar to the extensible hash output. Note that
039 * Blake3 keyed hashes have the same performance as plain hashes; the key is used in initialization in place of a
040 * standard initialization vector used for plain hashing.</p>
041 * <pre>{@code
042 *      SecureRandom random = SecureRandom.getInstanceStrong();
043 *      byte[] key = new byte[32];
044 *      random.nextBytes(key);
045 *      Blake3 hasher = Blake3.initKeyedHash(key);
046 *      hasher.update("Hello, Alice!".getBytes(StandardCharsets.UTF_8));
047 *      byte[] mac = new byte[32];
048 *      hasher.doFinalize(mac);
049 * }</pre>
050 * <h2>Key Derivation</h2>
051 * <p>A specific hash mode for deriving session keys and other derived keys in a unique key derivation context
052 * identified by some sequence of bytes. These context strings should be unique but do not need to be kept secret.
053 * Additional input data is hashed for key material which can be finalized to derive subkeys.</p>
054 * <pre>{@code
055 *      String context = "org.apache.commons.codec.digest.Blake3Example";
056 *      byte[] sharedSecret = ...;
057 *      byte[] senderId = ...;
058 *      byte[] recipientId = ...;
059 *      Blake3 kdf = Blake3.initKeyDerivationFunction(context.getBytes(StandardCharsets.UTF_8));
060 *      kdf.update(sharedSecret);
061 *      kdf.update(senderId);
062 *      kdf.update(recipientId);
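 *      // doFinalize restarts output from the beginning on each call, so take distinct ranges of one output
 *      byte[] keys = kdf.doFinalize(64);
 *      byte[] txKey = Arrays.copyOfRange(keys, 0, 32);
 *      byte[] rxKey = Arrays.copyOfRange(keys, 32, 64);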
067 * }</pre>
068 * <p>
069 * Adapted from the ISC-licensed O(1) Cryptography library by Matt Sicker and ported from the reference public domain
070 * implementation by Jack O'Connor.
071 * </p>
072 *
073 * @see <a href="https://github.com/BLAKE3-team/BLAKE3">BLAKE3 hash function</a>
074 * @since 1.16
075 */
public final class Blake3 {

    /**
     * Incremental state for the chunk currently being absorbed (at most {@code CHUNK_LEN} bytes). The most recent block
     * is kept buffered instead of being compressed eagerly. Only the last block of a chunk is compressed with
     * {@code CHUNK_END}. Whether a block is the last one is only known once more input arrives or output is requested.
     */
    private static final class ChunkState {
        /** Chaining value carried from one block compression to the next; initially the key words. */
        private int[] chainingValue;
        private final long chunkCounter;
        private final int flags;

        /** Buffered, not yet compressed block; bytes at or past {@code blockLength} are always zero. */
        private final byte[] block = new byte[BLOCK_LEN];
        private int blockLength;
        private int blocksCompressed;

        private ChunkState(final int[] key, final long chunkCounter, final int flags) {
            chainingValue = key;
            this.chunkCounter = chunkCounter;
            this.flags = flags;
        }

        /** Returns the number of input bytes absorbed into this chunk so far. */
        private int length() {
            return BLOCK_LEN * blocksCompressed + blockLength;
        }

        /**
         * Returns the pending compression of the buffered (final) block of this chunk, flagged {@code CHUNK_END}.
         * Does not modify this state.
         */
        private Output output() {
            final int[] blockWords = unpackInts(block, BLOCK_INTS);
            final int outputFlags = flags | startFlag() | CHUNK_END;
            return new Output(chainingValue, blockWords, chunkCounter, blockLength, outputFlags);
        }

        /** Returns {@code CHUNK_START} while no block of this chunk has been compressed yet, otherwise 0. */
        private int startFlag() {
            return blocksCompressed == 0 ? CHUNK_START : 0;
        }

        /**
         * Absorbs {@code length} bytes of {@code input} starting at {@code offset}. The caller
         * ({@code EngineState.inputData}) never passes more bytes than remain in the chunk.
         */
        private void update(final byte[] input, int offset, int length) {
            while (length > 0) {
                if (blockLength == BLOCK_LEN) {
                    // If the block buffer is full, compress it and clear it. More
                    // input is coming, so this compression is not CHUNK_END.
                    final int[] blockWords = unpackInts(block, BLOCK_INTS);
                    chainingValue = Arrays.copyOf(
                            compress(chainingValue, blockWords, BLOCK_LEN, chunkCounter, flags | startFlag()),
                            CHAINING_VALUE_INTS);
                    blocksCompressed++;
                    blockLength = 0;
                    // Zero the buffer so a short final block is implicitly zero-padded.
                    Arrays.fill(block, (byte) 0);
                }

                final int want = BLOCK_LEN - blockLength;
                final int take = Math.min(want, length);
                System.arraycopy(input, offset, block, blockLength, take);
                blockLength += take;
                offset += take;
                length -= take;
            }
        }
    }

    /**
     * Tree hashing state: the chunk currently being filled plus a stack of chaining values of completed subtrees to its
     * left. The key and flags select the hashing mode and never change for the lifetime of this state.
     */
    private static final class EngineState {
        private final int[] key;
        private final int flags;
        // Space for 54 subtree chaining values: 2^54 * CHUNK_LEN = 2^64
        // No more than 54 entries can ever be added to this stack (after updating 2^64 bytes and not finalizing any)
        // so we preallocate the stack here. This can be smaller in environments where the data limit is expected to
        // be much lower.
        private final int[][] cvStack = new int[54][];
        private int stackLen;
        private ChunkState state;

        private EngineState(final int[] key, final int flags) {
            this.key = key;
            this.flags = flags;
            state = new ChunkState(key, 0, flags);
        }

        // Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail.
        private void addChunkCV(final int[] firstCV, final long totalChunks) {
            // This chunk might complete some subtrees. For each completed subtree,
            // its left child will be the current top entry in the CV stack, and
            // its right child will be the current value of `newCV`. Pop each left
            // child off the stack, merge it with `newCV`, and overwrite `newCV`
            // with the result. After all these merges, push the final value of
            // `newCV` onto the stack. The number of completed subtrees is given
            // by the number of trailing 0-bits in the new total number of chunks.
            int[] newCV = firstCV;
            long chunkCounter = totalChunks;
            while ((chunkCounter & 1) == 0) {
                newCV = parentChainingValue(popCV(), newCV, key, flags);
                chunkCounter >>= 1;
            }
            pushCV(newCV);
        }

        /** Absorbs {@code length} bytes of {@code in} starting at {@code offset}, splitting them into chunks. */
        private void inputData(final byte[] in, int offset, int length) {
            while (length > 0) {
                // If the current chunk is complete, finalize it and reset the
                // chunk state. More input is coming, so this chunk is not ROOT.
                if (state.length() == CHUNK_LEN) {
                    final int[] chunkCV = state.output().chainingValue();
                    final long totalChunks = state.chunkCounter + 1;
                    addChunkCV(chunkCV, totalChunks);
                    state = new ChunkState(key, totalChunks, flags);
                }

                // Compress input bytes into the current chunk state.
                final int want = CHUNK_LEN - state.length();
                final int take = Math.min(want, length);
                state.update(in, offset, take);
                offset += take;
                length -= take;
            }
        }

        /**
         * Writes the first {@code length} bytes of root output into {@code out} at {@code offset}. This only reads the
         * engine state, so calling it repeatedly without further input yields the same bytes.
         */
        private void outputHash(final byte[] out, final int offset, final int length) {
            // Starting with the Output from the current chunk, compute all the
            // parent chaining values along the right edge of the tree, until we
            // have the root Output.
            Output output = state.output();
            int parentNodesRemaining = stackLen;
            while (parentNodesRemaining-- > 0) {
                final int[] parentCV = cvStack[parentNodesRemaining];
                output = parentOutput(parentCV, output.chainingValue(), key, flags);
            }
            output.rootOutputBytes(out, offset, length);
        }

        private int[] popCV() {
            return cvStack[--stackLen];
        }

        private void pushCV(final int[] cv) {
            cvStack[stackLen++] = cv;
        }

        /** Discards all absorbed input while keeping the key and mode flags. */
        private void reset() {
            stackLen = 0;
            Arrays.fill(cvStack, null);
            state = new ChunkState(key, 0, flags);
        }
    }

    /**
     * Represents the state just prior to either producing an eight word chaining value or any number of output bytes
     * when the ROOT flag is set.
     */
    private static final class Output {
        private final int[] inputChainingValue;
        private final int[] blockWords;
        private final long counter;
        private final int blockLength;
        private final int flags;

        private Output(
                final int[] inputChainingValue, final int[] blockWords, final long counter, final int blockLength,
                final int flags) {
            this.inputChainingValue = inputChainingValue;
            this.blockWords = blockWords;
            this.counter = counter;
            this.blockLength = blockLength;
            this.flags = flags;
        }

        /** Compresses this node and returns the first eight words of the result as its chaining value. */
        private int[] chainingValue() {
            return Arrays
                    .copyOf(compress(inputChainingValue, blockWords, blockLength, counter, flags), CHAINING_VALUE_INTS);
        }

        /**
         * Writes {@code length} bytes of extended root output starting from output block 0. The first {@code n} bytes
         * are therefore identical no matter how many bytes are requested. Each ROOT-flagged compression contributes 16
         * little-endian words (64 bytes).
         */
        private void rootOutputBytes(final byte[] out, int offset, int length) {
            int outputBlockCounter = 0;
            while (length > 0) {
                int chunkLength = Math.min(OUT_LEN * 2, length);
                length -= chunkLength;
                final int[] words =
                        compress(inputChainingValue, blockWords, blockLength, outputBlockCounter++, flags | ROOT);
                int wordCounter = 0;
                while (chunkLength > 0) {
                    final int wordLength = Math.min(Integer.BYTES, chunkLength);
                    packInt(words[wordCounter++], out, offset, wordLength);
                    offset += wordLength;
                    chunkLength -= wordLength;
                }
            }
        }
    }

    private static final int BLOCK_LEN = 64;
    private static final int BLOCK_INTS = BLOCK_LEN / Integer.BYTES;
    private static final int KEY_LEN = 32;
    private static final int KEY_INTS = KEY_LEN / Integer.BYTES;

    private static final int OUT_LEN = 32;

    private static final int CHUNK_LEN = 1024;
    private static final int CHAINING_VALUE_INTS = 8;
    /**
     * Standard hash key used for plain hashes; same initialization vector as Blake2s.
     */
    private static final int[] IV =
            { 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 };
    // domain flags
    private static final int CHUNK_START = 1;
    private static final int CHUNK_END = 1 << 1;
    private static final int PARENT = 1 << 2;
    private static final int ROOT = 1 << 3;

    private static final int KEYED_HASH = 1 << 4;

    private static final int DERIVE_KEY_CONTEXT = 1 << 5;

    private static final int DERIVE_KEY_MATERIAL = 1 << 6;

    /**
     * Pre-permuted for all 7 rounds; the second row (2,6,3,...) indicates the base permutation.
     */
    private static final byte[][] MSG_SCHEDULE = {
            { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 },
            { 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 },
            { 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 },
            { 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 },
            { 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 },
            { 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 },
            { 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 }
    };

    /**
     * Validates that {@code [offset, offset + length)} is a non-empty-or-empty range inside {@code buffer}.
     * The subtraction form of the final check cannot overflow.
     */
    private static void checkBufferArgs(final byte[] buffer, final int offset, final int length) {
        Objects.requireNonNull(buffer);
        if (offset < 0) {
            throw new IndexOutOfBoundsException("Offset must be non-negative");
        }
        if (length < 0) {
            throw new IndexOutOfBoundsException("Length must be non-negative");
        }
        final int bufferLength = buffer.length;
        if (offset > bufferLength - length) {
            throw new IndexOutOfBoundsException(
                    "Offset " + offset + " and length " + length + " out of bounds with buffer length " + bufferLength);
        }
    }

    /**
     * The BLAKE3 compression function. It runs 7 rounds over a 16-word state seeded from the chaining value, the first
     * four IV words, the 64-bit counter (low word first), the block length and the flags. It returns all 16 output
     * words: callers keep the first 8 as a chaining value, and root output uses all 16.
     */
    private static int[] compress(
            final int[] chainingValue, final int[] blockWords, final int blockLength, final long counter,
            final int flags) {
        final int[] state = Arrays.copyOf(chainingValue, BLOCK_INTS);
        System.arraycopy(IV, 0, state, 8, 4);
        state[12] = (int) counter;
        state[13] = (int) (counter >> Integer.SIZE);
        state[14] = blockLength;
        state[15] = flags;
        for (int i = 0; i < 7; i++) {
            final byte[] schedule = MSG_SCHEDULE[i];
            round(state, blockWords, schedule);
        }
        for (int i = 0; i < state.length / 2; i++) {
            state[i] ^= state[i + 8];
            state[i + 8] ^= chainingValue[i];
        }
        return state;
    }

    /**
     * The mixing function, G, which mixes either a column or a diagonal.
     */
    private static void g(
            final int[] state, final int a, final int b, final int c, final int d, final int mx, final int my) {
        state[a] += state[b] + mx;
        state[d] = Integer.rotateRight(state[d] ^ state[a], 16);
        state[c] += state[d];
        state[b] = Integer.rotateRight(state[b] ^ state[c], 12);
        state[a] += state[b] + my;
        state[d] = Integer.rotateRight(state[d] ^ state[a], 8);
        state[c] += state[d];
        state[b] = Integer.rotateRight(state[b] ^ state[c], 7);
    }

    /**
     * Calculates the Blake3 hash of the provided data.
     *
     * @param data source array to absorb data from
     * @return 32-byte hash squeezed from the provided data
     * @throws NullPointerException if data is null
     */
    public static byte[] hash(final byte[] data) {
        return Blake3.initHash().update(data).doFinalize(OUT_LEN);
    }

    /**
     * Constructs a fresh Blake3 hash function. The instance returned functions as an arbitrary length message digest.
     *
     * @return fresh Blake3 instance in hashed mode
     */
    public static Blake3 initHash() {
        return new Blake3(IV, 0);
    }

    /**
     * Constructs a fresh Blake3 key derivation function using the provided key derivation context byte string.
     * The context is first hashed on its own to produce a 32-byte context key. That key then keys the returned
     * instance, which absorbs key material before squeezing derived key data.
     *
     * @param kdfContext a globally unique key-derivation context byte string to separate key derivation contexts from
     *                   each other
     * @return fresh Blake3 instance in key derivation mode
     * @throws NullPointerException if kdfContext is null
     */
    public static Blake3 initKeyDerivationFunction(final byte[] kdfContext) {
        Objects.requireNonNull(kdfContext);
        final EngineState kdf = new EngineState(IV, DERIVE_KEY_CONTEXT);
        kdf.inputData(kdfContext, 0, kdfContext.length);
        final byte[] key = new byte[KEY_LEN];
        kdf.outputHash(key, 0, key.length);
        return new Blake3(unpackInts(key, KEY_INTS), DERIVE_KEY_MATERIAL);
    }

    /**
     * Constructs a fresh Blake3 keyed hash function. The instance returned functions as a pseudorandom function (PRF) or
     * as a message authentication code (MAC).
     *
     * @param key 32-byte secret key
     * @return fresh Blake3 instance in keyed mode using the provided key
     * @throws NullPointerException     if key is null
     * @throws IllegalArgumentException if key is not 32 bytes
     */
    public static Blake3 initKeyedHash(final byte[] key) {
        Objects.requireNonNull(key);
        if (key.length != KEY_LEN) {
            throw new IllegalArgumentException("Blake3 keys must be 32 bytes");
        }
        return new Blake3(unpackInts(key, KEY_INTS), KEYED_HASH);
    }

    /**
     * Calculates the Blake3 keyed hash (MAC) of the provided data.
     *
     * @param key  32-byte secret key
     * @param data source array to absorb data from
     * @return 32-byte mac squeezed from the provided data
     * @throws NullPointerException     if key or data are null
     * @throws IllegalArgumentException if key is not 32 bytes
     */
    public static byte[] keyedHash(final byte[] key, final byte[] data) {
        return Blake3.initKeyedHash(key).update(data).doFinalize(OUT_LEN);
    }

    /** Writes the lowest {@code len} bytes of {@code value} into {@code dst} at {@code off}, little-endian. */
    private static void packInt(final int value, final byte[] dst, final int off, final int len) {
        for (int i = 0; i < len; i++) {
            dst[off + i] = (byte) (value >>> i * Byte.SIZE);
        }
    }

    /** Returns the chaining value of the parent node whose children have the given chaining values. */
    private static int[] parentChainingValue(
            final int[] leftChildCV, final int[] rightChildCV, final int[] key, final int flags) {
        return parentOutput(leftChildCV, rightChildCV, key, flags).chainingValue();
    }

    /** Builds the pending parent-node compression: the block is the left CV followed by the right CV. */
    private static Output parentOutput(
            final int[] leftChildCV, final int[] rightChildCV, final int[] key, final int flags) {
        final int[] blockWords = Arrays.copyOf(leftChildCV, BLOCK_INTS);
        System.arraycopy(rightChildCV, 0, blockWords, 8, CHAINING_VALUE_INTS);
        return new Output(key.clone(), blockWords, 0, BLOCK_LEN, flags | PARENT);
    }

    /** One round: G over the four columns, then over the four diagonals, using the given message word order. */
    private static void round(final int[] state, final int[] msg, final byte[] schedule) {
        // Mix the columns.
        g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
        g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
        g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
        g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);

        // Mix the diagonals.
        g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
        g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
        g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
        g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
    }

    /** Reads a little-endian 32-bit word from {@code buf} at {@code off}. */
    private static int unpackInt(final byte[] buf, final int off) {
        return buf[off] & 0xFF | (buf[off + 1] & 0xFF) << 8 | (buf[off + 2] & 0xFF) << 16 | (buf[off + 3] & 0xFF) << 24;
    }

    /** Reads {@code nrInts} consecutive little-endian words from the start of {@code buf}. */
    private static int[] unpackInts(final byte[] buf, final int nrInts) {
        final int[] values = new int[nrInts];
        for (int i = 0, off = 0; i < nrInts; i++, off += Integer.BYTES) {
            values[i] = unpackInt(buf, off);
        }
        return values;
    }

    private final EngineState engineState;

    private Blake3(final int[] key, final int flags) {
        engineState = new EngineState(key, flags);
    }

    /**
     * Finalizes hash output into the whole of {@code out}. The output is a deterministic function of all bytes absorbed
     * since this instance was constructed or last {@linkplain #reset() reset}. It does <em>not</em> depend on earlier
     * calls to this method. Finalizing leaves the hash state unchanged: calling this again without an intervening update
     * produces the same bytes, and more data may still be absorbed afterwards. Outputs of different lengths are
     * prefix-consistent. To derive several distinct values, take distinct ranges of one longer output.
     *
     * @param out destination array to finalize bytes into
     * @return this
     * @throws NullPointerException if out is null
     */
    public Blake3 doFinalize(final byte[] out) {
        return doFinalize(out, 0, out.length);
    }

    /**
     * Finalizes {@code length} bytes of hash output into {@code out} starting at {@code offset}. The output is a
     * deterministic function of all bytes absorbed since this instance was constructed or last
     * {@linkplain #reset() reset}. Each call restarts the extendable output from its beginning rather than continuing
     * where a previous call stopped. Repeated calls without an intervening update therefore produce the same bytes, and
     * the first {@code n} bytes are identical whatever {@code length} is requested. The hash state is not modified, so
     * more data may still be absorbed afterwards.
     *
     * @param out    destination array to finalize bytes into
     * @param offset where in the array to begin writing bytes to
     * @param length number of bytes to finalize
     * @return this
     * @throws NullPointerException      if out is null
     * @throws IndexOutOfBoundsException if offset or length are negative or if offset + length is greater than the
     *                                   length of the provided array
     */
    public Blake3 doFinalize(final byte[] out, final int offset, final int length) {
        checkBufferArgs(out, offset, length);
        engineState.outputHash(out, offset, length);
        return this;
    }

    /**
     * Squeezes and returns {@code nrBytes} bytes of hash output. The output depends only on the bytes absorbed since this
     * instance was constructed or last {@linkplain #reset() reset}, not on earlier finalizations. Repeated calls return
     * the same bytes, and a shorter result is always a prefix of a longer one.
     *
     * @param nrBytes number of bytes to finalize
     * @return requested number of finalized bytes
     * @throws IllegalArgumentException if nrBytes is negative
     */
    public byte[] doFinalize(final int nrBytes) {
        if (nrBytes < 0) {
            throw new IllegalArgumentException("Requested bytes must be non-negative");
        }
        final byte[] hash = new byte[nrBytes];
        doFinalize(hash);
        return hash;
    }

    /**
     * Resets this instance back to its initial state when it was first constructed, discarding all absorbed data. The
     * mode and key chosen at construction are retained.
     *
     * @return this
     */
    public Blake3 reset() {
        engineState.reset();
        return this;
    }

    /**
     * Updates this hash state using the provided bytes.
     *
     * @param in source array to update data from
     * @return this
     * @throws NullPointerException if in is null
     */
    public Blake3 update(final byte[] in) {
        return update(in, 0, in.length);
    }

    /**
     * Updates this hash state using the provided bytes at an offset.
     *
     * @param in     source array to update data from
     * @param offset where in the array to begin reading bytes
     * @param length number of bytes to update
     * @return this
     * @throws NullPointerException      if in is null
     * @throws IndexOutOfBoundsException if offset or length are negative or if offset + length is greater than the
     *                                   length of the provided array
     */
    public Blake3 update(final byte[] in, final int offset, final int length) {
        checkBufferArgs(in, offset, length);
        engineState.inputData(in, offset, length);
        return this;
    }

}