package crypto import ( "encoding/binary" "fmt" "math/big" "slices" ) var ( BELT_ZERO = Belt{Value: 0} BELT_ONE = Belt{Value: 1} F6_ZERO = [6]Belt{BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO} F6_ONE = [6]Belt{BELT_ONE, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO} A_ID = CheetahPoint{ X: F6_ZERO, Y: F6_ONE, Inf: true, } A_GEN = CheetahPoint{ X: [6]Belt{ {Value: 2754611494552410273}, {Value: 8599518745794843693}, {Value: 10526511002404673680}, {Value: 4830863958577994148}, {Value: 375185138577093320}, {Value: 12938930721685970739}, }, Y: [6]Belt{ {Value: 15384029202802550068}, {Value: 2774812795997841935}, {Value: 14375303400746062753}, {Value: 10708493419890101954}, {Value: 13187678623570541764}, {Value: 9990732138772505951}, }, Inf: false, } P_BIG = new(big.Int).SetUint64(PRIME) P_BIG_2 = new(big.Int).Mul(P_BIG, P_BIG) P_BIG_3 = new(big.Int).Mul(P_BIG_2, P_BIG) G_ORDER, _ = new(big.Int).SetString("55610362957290864006699123731285679659474893560816383126640993521607086746831", 10) ) type CheetahPoint struct { X, Y [6]Belt Inf bool } func (p *CheetahPoint) Bytes() []byte { bytes := []byte{0x1} for i := 5; i >= 0; i-- { belt := p.Y[i].Value buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, belt) bytes = append(bytes, buf...) } for i := 5; i >= 0; i-- { belt := p.X[i].Value buf := make([]byte, 8) binary.BigEndian.PutUint64(buf, belt) bytes = append(bytes, buf...) } return bytes } func CheetaPointFromBytes(bytes []byte) (CheetahPoint, error) { if len(bytes) != 97 { return CheetahPoint{}, fmt.Errorf("invalid point length %d, length must be 97", len(bytes)) } bytes = bytes[1:] if len(bytes)%8 != 0 { return CheetahPoint{}, fmt.Errorf("input length not multiple of 8") } n := len(bytes) / 8 belts := []Belt{} for i := 0; i < n; i++ { chunk := bytes[i*8 : (i+1)*8] belt := binary.BigEndian.Uint64(chunk) belts = append(belts, Belt{Value: belt}) } slices.Reverse(belts) point := CheetahPoint{ X: [6]Belt(belts[0:6]), Y: [6]Belt(belts[6:]), Inf: false, } if point.InCurve() { return point, nil } return CheetahPoint{}, fmt.Errorf("point not in curve") } func (p *CheetahPoint) InCurve() bool { if *p == A_ID { return true } scaled := CheetahScaleBig(*p, *G_ORDER) return scaled == A_ID } func CheetahScaleBig(p CheetahPoint, n big.Int) CheetahPoint { zero := big.NewInt(0) nCopy := new(big.Int).Set(&n) acc := A_ID for nCopy.Cmp(zero) > 0 { if nCopy.Bit(0) == 1 { acc = CheetahAdd(acc, p) } p = CheetahDouble(p) nCopy.Rsh(nCopy, 1) } return acc } func CheetahAdd(p, q CheetahPoint) CheetahPoint { if p.Inf { return q } if q.Inf { return p } if p == CheetahNeg(q) { return A_ID } if p == q { return CheetahDouble(p) } return cheetahAddUnsafe(&p, &q) } func CheetahNeg(p CheetahPoint) CheetahPoint { negP := p for i := 0; i < 6; i++ { negP.Y[i] = p.Y[i].Neg() } return negP } func CheetahDouble(p CheetahPoint) CheetahPoint { if p.Inf { return A_ID } if p.Y == F6_ZERO { return A_ID } return cheetahDoubleUnsafe(&p) } func cheetahAddUnsafe(p, q *CheetahPoint) CheetahPoint { slope := f6Mul(f6Sub(p.Y, q.Y), f6Inv(f6Sub(p.X, q.X))) xOut := f6Sub(f6Square(slope), f6Add(p.X, q.X)) yOut := f6Sub(f6Mul(slope, f6Sub(p.X, xOut)), p.Y) return CheetahPoint{ X: xOut, Y: yOut, Inf: false, } } func cheetahDoubleUnsafe(p *CheetahPoint) CheetahPoint { slope := f6Mul( f6Add(f6ScalarMul(f6Square(p.X), Belt{Value: 3}), F6_ONE), f6Inv(f6ScalarMul(p.Y, Belt{Value: 2})), ) xOut := f6Sub(f6Square(slope), f6ScalarMul(p.X, Belt{Value: 2})) yOut := f6Sub(f6Mul(slope, f6Sub(p.X, xOut)), p.Y) return CheetahPoint{ X: xOut, Y: yOut, Inf: false, } } func f6Inv(a [6]Belt) [6]Belt { if a == F6_ZERO { panic("Cannot invert zero") } aPoly := Bpoly{Value: a[:]} b := Bpoly{ Value: []Belt{ Belt{Value: bneg(7)}, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ZERO, BELT_ONE, }, } d, u, _ := Egcd(aPoly, b) dInv := d.Value[0].Inv() res := u.ScalarMul(dInv, 6) return [6]Belt{res.Value[0], res.Value[1], res.Value[2], res.Value[3], res.Value[4], res.Value[5]} } func f6Mul(a, b [6]Belt) [6]Belt { a0b0 := karat3([3]Belt{a[0], a[1], a[2]}, [3]Belt{b[0], b[1], b[2]}) a1b1 := karat3([3]Belt{a[3], a[4], a[5]}, [3]Belt{b[3], b[4], b[5]}) foil := karat3( [3]Belt{a[0].Add(a[3]), a[1].Add(a[4]), a[2].Add(a[5])}, [3]Belt{b[0].Add(b[3]), b[1].Add(b[4]), b[2].Add(b[5])}, ) foil_0 := foil[0].Sub(a0b0[0]) foil_1 := foil[1].Sub(a0b0[1]) foil_2 := foil[2].Sub(a0b0[2]) foil_3 := foil[3].Sub(a0b0[3]) foil_4 := foil[4].Sub(a0b0[4]) cross := [5]Belt{ foil_0.Sub(a1b1[0]), foil_1.Sub(a1b1[1]), foil_2.Sub(a1b1[2]), foil_3.Sub(a1b1[3]), foil_4.Sub(a1b1[4]), } seven := Belt{Value: 7} a0b0cross0 := a0b0[3].Add(cross[0]) a0b0cross1 := a0b0[4].Add(cross[1]) return [6]Belt{ a0b0[0].Add(seven.Mul(cross[3].Add(a1b1[0]))), a0b0[1].Add(seven.Mul(cross[4].Add(a1b1[1]))), a0b0[2].Add(seven.Mul(a1b1[2])), a0b0cross0.Add(seven.Mul(a1b1[3])), a0b0cross1.Add(seven.Mul(a1b1[4])), cross[2], } } func f6ScalarMul(a [6]Belt, scalar Belt) [6]Belt { res := F6_ZERO for i := range res { res[i] = a[i].Mul(scalar) } return res } func f6Add(a, b [6]Belt) [6]Belt { res := F6_ZERO for i := range res { res[i] = a[i].Add(b[i]) } return res } func f6Sub(a, b [6]Belt) [6]Belt { return f6Add(a, f6Neg(b)) } func f6Neg(a [6]Belt) [6]Belt { res := F6_ZERO for i := range res { res[i] = a[i].Neg() } return res } func f6Square(a [6]Belt) [6]Belt { return f6Mul(a, a) } func karat3(a, b [3]Belt) [5]Belt { m := [3]Belt{ a[0].Mul(b[0]), a[1].Mul(b[1]), a[2].Mul(b[2]), } a0a1 := a[0].Add(a[1]) b0b1 := b[0].Add(b[1]) m0m1 := m[0].Add(m[1]) a0a1b0b1 := a0a1.Mul(b0b1) a0a2 := a[0].Add(a[2]) b0b2 := b[0].Add(b[2]) m0m2 := m[0].Add(m[2]) m0m2m1 := m0m2.Sub(m[1]) a0a2b0b2 := a0a2.Mul(b0b2) a1a2 := a[1].Add(a[2]) b1b2 := b[1].Add(b[2]) m1m2 := m[1].Add(m[2]) a1a2b1b2 := a1a2.Mul(b1b2) return [5]Belt{ m[0], a0a1b0b1.Sub(m0m1), a0a2b0b2.Sub(m0m2m1), a1a2b1b2.Sub(m1m2), m[2], } }