nockchain-grpc/crypto/cheetah.go
2025-10-06 13:38:53 +07:00

309 lines
6.1 KiB
Go

package crypto
import (
"encoding/binary"
"fmt"
"math/big"
"slices"
)
// Package-level values shared by the Cheetah point and F6 helpers in this file.
var (
// BELT_ZERO and BELT_ONE are the Belt values 0 and 1.
BELT_ZERO = Belt{Value: 0}
BELT_ONE = Belt{Value: 1}
// F6_ZERO has all six coefficients zero; F6_ONE has BELT_ONE in
// coefficient 0 and zeros elsewhere.
F6_ZERO = [6]Belt{BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO}
F6_ONE = [6]Belt{BELT_ONE, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO}
// A_ID is the point at infinity (Inf set). CheetahAdd, CheetahDouble and
// CheetahScaleBig return it as the identity, and InCurve compares against
// it with ==, so its exact X/Y placeholders matter.
A_ID = CheetahPoint{
X: F6_ZERO,
Y: F6_ONE,
Inf: true,
}
// A_GEN is a fixed finite point (Inf false).
// NOTE(review): presumably the group generator matching G_ORDER — confirm.
A_GEN = CheetahPoint{
X: [6]Belt{
{Value: 2754611494552410273},
{Value: 8599518745794843693},
{Value: 10526511002404673680},
{Value: 4830863958577994148},
{Value: 375185138577093320},
{Value: 12938930721685970739},
},
Y: [6]Belt{
{Value: 15384029202802550068},
{Value: 2774812795997841935},
{Value: 14375303400746062753},
{Value: 10708493419890101954},
{Value: 13187678623570541764},
{Value: 9990732138772505951},
},
Inf: false,
}
// P_BIG is PRIME as a big.Int; P_BIG_2 and P_BIG_3 are its square and cube.
P_BIG = new(big.Int).SetUint64(PRIME)
P_BIG_2 = new(big.Int).Mul(P_BIG, P_BIG)
P_BIG_3 = new(big.Int).Mul(P_BIG_2, P_BIG)
// G_ORDER is the scalar InCurve multiplies a point by before comparing with
// A_ID. The ok result of SetString is discarded; the literal is a plain
// base-10 digit string.
G_ORDER, _ = new(big.Int).SetString("55610362957290864006699123731285679659474893560816383126640993521607086746831", 10)
)
// CheetahPoint is a point given by two coordinates of six Belt coefficients
// each, plus a flag for the point at infinity. Points are compared with ==
// throughout this file, so two values are equal only when X, Y and Inf all
// match exactly.
type CheetahPoint struct {
// X and Y are the coordinates, coefficient 0 first.
X, Y [6]Belt
// Inf marks the point at infinity; CheetahAdd and CheetahDouble check it
// before looking at X or Y.
Inf bool
}
// Bytes serializes the point into 97 bytes: a leading 0x1 byte, then the six
// Y coefficients from index 5 down to 0, then the six X coefficients from
// index 5 down to 0. Each coefficient is written as a big-endian uint64.
// The Inf flag is not encoded.
func (p *CheetahPoint) Bytes() []byte {
	const (
		limbSize   = 8
		coordLimbs = 6
	)
	out := make([]byte, 1+2*coordLimbs*limbSize)
	out[0] = 0x1
	for i := 0; i < coordLimbs; i++ {
		// Highest-index coefficient goes first within each coordinate.
		yOff := 1 + i*limbSize
		xOff := yOff + coordLimbs*limbSize
		binary.BigEndian.PutUint64(out[yOff:], p.Y[coordLimbs-1-i].Value)
		binary.BigEndian.PutUint64(out[xOff:], p.X[coordLimbs-1-i].Value)
	}
	return out
}
// CheetaPointFromBytes decodes the 97-byte form written by CheetahPoint.Bytes:
// one leading byte (skipped, not validated) followed by twelve big-endian
// uint64 limbs, Y[5]..Y[0] then X[5]..X[0].
//
// It returns an error when the input is not exactly 97 bytes, when any limb
// is not a canonical field element (>= the field modulus), or when the
// decoded point fails InCurve. The returned point always has Inf == false.
func CheetaPointFromBytes(bytes []byte) (CheetahPoint, error) {
	const (
		limbSize  = 8
		limbCount = 12
		// Goldilocks prime 2^64 - 2^32 + 1.
		// NOTE(review): assumed to equal the package-level PRIME — confirm,
		// and consider referencing PRIME directly.
		fieldModulus uint64 = 0xFFFFFFFF00000001
	)

	if len(bytes) != 97 {
		return CheetahPoint{}, fmt.Errorf("invalid point length %d, length must be 97", len(bytes))
	}
	payload := bytes[1:]

	belts := make([]Belt, 0, limbCount)
	for i := 0; i < limbCount; i++ {
		limb := binary.BigEndian.Uint64(payload[i*limbSize : (i+1)*limbSize])
		// Reject unreduced limbs: the arithmetic compares elements with ==
		// (e.g. the zero-Y guard in CheetahDouble), so a non-canonical zero
		// would slip past those guards and reach the panic in f6Inv.
		if limb >= fieldModulus {
			return CheetahPoint{}, fmt.Errorf("non-canonical field element %d at limb %d", limb, i)
		}
		belts = append(belts, Belt{Value: limb})
	}

	// The wire order is Y[5]..Y[0], X[5]..X[0]; reversing yields X[0..5], Y[0..5].
	slices.Reverse(belts)
	point := CheetahPoint{
		X:   [6]Belt(belts[0:6]),
		Y:   [6]Belt(belts[6:]),
		Inf: false,
	}
	if !point.InCurve() {
		return CheetahPoint{}, fmt.Errorf("point not in curve")
	}
	return point, nil
}
// InCurve reports whether p is A_ID or is sent to A_ID when multiplied by
// G_ORDER. Note that this is an order check only: the curve equation itself
// is not evaluated here.
func (p *CheetahPoint) InCurve() bool {
	candidate := *p
	return candidate == A_ID || CheetahScaleBig(candidate, *G_ORDER) == A_ID
}
// CheetahScaleBig multiplies p by the scalar n using double-and-add, reading
// the bits of n from least to most significant. A zero or negative n yields
// A_ID. n is only read, never modified.
func CheetahScaleBig(p CheetahPoint, n big.Int) CheetahPoint {
	result := A_ID
	if n.Sign() <= 0 {
		return result
	}
	for i, width := 0, n.BitLen(); i < width; i++ {
		if n.Bit(i) != 0 {
			result = CheetahAdd(result, p)
		}
		// p holds 2^(i+1) times the input for the next bit.
		p = CheetahDouble(p)
	}
	return result
}
// CheetahAdd returns p + q. The special cases are handled before the generic
// formula, in this order: either operand at infinity, q being the negation of
// p (result A_ID), and p equal to q (delegated to CheetahDouble).
func CheetahAdd(p, q CheetahPoint) CheetahPoint {
	switch {
	case p.Inf:
		return q
	case q.Inf:
		return p
	case p == CheetahNeg(q):
		return A_ID
	case p == q:
		return CheetahDouble(p)
	default:
		// Remaining case: the generic formula, which inverts p.X - q.X.
		return cheetahAddUnsafe(&p, &q)
	}
}
// CheetahNeg returns a copy of p with every Y coefficient negated. X and Inf
// are carried over unchanged.
func CheetahNeg(p CheetahPoint) CheetahPoint {
	// p is already a private copy, so it can be updated in place.
	for i := range p.Y {
		p.Y[i] = p.Y[i].Neg()
	}
	return p
}
// CheetahDouble returns p + p. The point at infinity and any point whose Y
// coordinate equals F6_ZERO double to A_ID; every other point goes through
// the generic doubling formula.
func CheetahDouble(p CheetahPoint) CheetahPoint {
	if p.Inf || p.Y == F6_ZERO {
		return A_ID
	}
	return cheetahDoubleUnsafe(&p)
}
// cheetahAddUnsafe applies the chord formula to two finite points:
//
//	slope = (p.Y - q.Y) / (p.X - q.X)
//	x     = slope^2 - (p.X + q.X)
//	y     = slope * (p.X - x) - p.Y
//
// It performs no special-case checks; f6Inv panics when p.X == q.X, so
// callers must rule that out first (CheetahAdd does).
func cheetahAddUnsafe(p, q *CheetahPoint) CheetahPoint {
	dy := f6Sub(p.Y, q.Y)
	dx := f6Sub(p.X, q.X)
	lambda := f6Mul(dy, f6Inv(dx))

	xSum := f6Add(p.X, q.X)
	x3 := f6Sub(f6Square(lambda), xSum)
	y3 := f6Sub(f6Mul(lambda, f6Sub(p.X, x3)), p.Y)

	return CheetahPoint{X: x3, Y: y3}
}
// cheetahDoubleUnsafe applies the tangent formula to a finite point:
//
//	slope = (3*p.X^2 + 1) / (2*p.Y)
//	x     = slope^2 - 2*p.X
//	y     = slope * (p.X - x) - p.Y
//
// It performs no special-case checks; f6Inv panics when 2*p.Y is zero, so
// callers must rule that out first (CheetahDouble does).
func cheetahDoubleUnsafe(p *CheetahPoint) CheetahPoint {
	two := Belt{Value: 2}
	three := Belt{Value: 3}

	numer := f6Add(f6ScalarMul(f6Square(p.X), three), F6_ONE)
	denom := f6ScalarMul(p.Y, two)
	lambda := f6Mul(numer, f6Inv(denom))

	x2 := f6Sub(f6Square(lambda), f6ScalarMul(p.X, two))
	y2 := f6Sub(f6Mul(lambda, f6Sub(p.X, x2)), p.Y)

	return CheetahPoint{X: x2, Y: y2}
}
// f6Inv returns the multiplicative inverse of a. It panics with
// "Cannot invert zero" when a equals F6_ZERO.
//
// The inverse is obtained by running Egcd on a (as a polynomial) and the
// degree-6 polynomial whose constant term is bneg(7) and whose leading
// coefficient is one, then scaling the second Egcd result by the inverse of
// the first result's constant term.
// NOTE(review): presumably bneg(7) is -7, making the modulus x^6 - 7, and
// Egcd returns (gcd, u, v) with u*a + v*modulus = gcd — confirm against
// their definitions.
func f6Inv(a [6]Belt) [6]Belt {
	if a == F6_ZERO {
		panic("Cannot invert zero")
	}

	// Coefficients are stored lowest degree first: index 0 and index 6 are
	// the only non-zero entries.
	modulus := make([]Belt, 7)
	for i := range modulus {
		modulus[i] = BELT_ZERO
	}
	modulus[0] = Belt{Value: bneg(7)}
	modulus[6] = BELT_ONE

	gcd, bezout, _ := Egcd(Bpoly{Value: a[:]}, Bpoly{Value: modulus})

	// Normalize by the constant term of the gcd so that bezout*a reduces to one.
	scale := gcd.Value[0].Inv()
	normalized := bezout.ScalarMul(scale, 6)

	var out [6]Belt
	for i := range out {
		out[i] = normalized.Value[i]
	}
	return out
}
// f6Mul multiplies two six-coefficient elements. Each operand is split into a
// low half (coefficients 0..2) and a high half (coefficients 3..5); three
// karat3 products (low*low, high*high, and sum*sum) give the low, high and
// cross terms. Terms of degree six and above are folded back with a factor
// of 7, i.e. the reduction uses x^6 = 7.
func f6Mul(a, b [6]Belt) [6]Belt {
	aLo := [3]Belt{a[0], a[1], a[2]}
	aHi := [3]Belt{a[3], a[4], a[5]}
	bLo := [3]Belt{b[0], b[1], b[2]}
	bHi := [3]Belt{b[3], b[4], b[5]}

	low := karat3(aLo, bLo)
	high := karat3(aHi, bHi)

	var aSum, bSum [3]Belt
	for i := range aSum {
		aSum[i] = aLo[i].Add(aHi[i])
		bSum[i] = bLo[i].Add(bHi[i])
	}
	mixed := karat3(aSum, bSum)

	// cross = (aLo+aHi)(bLo+bHi) - aLo*bLo - aHi*bHi = aLo*bHi + aHi*bLo.
	var cross [5]Belt
	for i := range cross {
		partial := mixed[i].Sub(low[i])
		cross[i] = partial.Sub(high[i])
	}

	// The full product is low + cross*x^3 + high*x^6; fold degrees >= 6.
	seven := Belt{Value: 7}
	c3 := low[3].Add(cross[0])
	c4 := low[4].Add(cross[1])
	return [6]Belt{
		low[0].Add(seven.Mul(cross[3].Add(high[0]))),
		low[1].Add(seven.Mul(cross[4].Add(high[1]))),
		low[2].Add(seven.Mul(high[2])),
		c3.Add(seven.Mul(high[3])),
		c4.Add(seven.Mul(high[4])),
		cross[2],
	}
}
// f6ScalarMul multiplies every coefficient of a by scalar.
func f6ScalarMul(a [6]Belt, scalar Belt) [6]Belt {
	var scaled [6]Belt
	for i := range a {
		scaled[i] = a[i].Mul(scalar)
	}
	return scaled
}
// f6Add adds a and b coefficient by coefficient.
func f6Add(a, b [6]Belt) [6]Belt {
	var sum [6]Belt
	for i := range a {
		sum[i] = a[i].Add(b[i])
	}
	return sum
}
// f6Sub returns a - b, computed per coefficient as a[i] + (-b[i]).
func f6Sub(a, b [6]Belt) [6]Belt {
	var diff [6]Belt
	for i := range diff {
		negated := b[i].Neg()
		diff[i] = a[i].Add(negated)
	}
	return diff
}
// f6Neg negates every coefficient of a.
func f6Neg(a [6]Belt) [6]Belt {
	var negated [6]Belt
	for i := range a {
		negated[i] = a[i].Neg()
	}
	return negated
}
// f6Square returns a multiplied by itself, via f6Mul.
func f6Square(a [6]Belt) [6]Belt {
	squared := f6Mul(a, a)
	return squared
}
// karat3 multiplies two three-coefficient polynomials (lowest degree first)
// and returns the five coefficients of the product. It uses six Belt
// multiplications: the three diagonal products a[i]*b[i] plus one product of
// pairwise sums for each of the index pairs (0,1), (0,2) and (1,2).
func karat3(a, b [3]Belt) [5]Belt {
	// Diagonal products.
	d0 := a[0].Mul(b[0])
	d1 := a[1].Mul(b[1])
	d2 := a[2].Mul(b[2])

	// Degree 1: (a0+a1)(b0+b1) - (d0+d1) = a0*b1 + a1*b0.
	aSum01 := a[0].Add(a[1])
	bSum01 := b[0].Add(b[1])
	prod01 := aSum01.Mul(bSum01)
	d0d1 := d0.Add(d1)

	// Degree 2: (a0+a2)(b0+b2) - (d0+d2-d1) = a0*b2 + a2*b0 + a1*b1.
	aSum02 := a[0].Add(a[2])
	bSum02 := b[0].Add(b[2])
	prod02 := aSum02.Mul(bSum02)
	d0d2 := d0.Add(d2)
	d0d2LessD1 := d0d2.Sub(d1)

	// Degree 3: (a1+a2)(b1+b2) - (d1+d2) = a1*b2 + a2*b1.
	aSum12 := a[1].Add(a[2])
	bSum12 := b[1].Add(b[2])
	prod12 := aSum12.Mul(bSum12)
	d1d2 := d1.Add(d2)

	return [5]Belt{
		d0,
		prod01.Sub(d0d1),
		prod02.Sub(d0d2LessD1),
		prod12.Sub(d1d2),
		d2,
	}
}