// Package crypto implements arithmetic over the prime field of order
// PRIME = 2^64 - 2^32 + 1. Field elements are represented by Belt.
package crypto

import (
	"fmt"
	"math/big"
	"math/bits"
)

const (
	// PRIME is the field modulus, 2^64 - 2^32 + 1.
	PRIME uint64 = 18446744069414584321
	// PRIME_PRIME is PRIME - 2. Raising a nonzero element to this power yields
	// its multiplicative inverse (Fermat's little theorem); binv computes
	// exactly this power with a fixed addition chain.
	PRIME_PRIME uint64 = PRIME - 2
	// H is not referenced in this file. OrderedRoot's reference formula is
	// bpow(H, ORDER/n), which suggests H generates the subgroup of order
	// ORDER. NOTE(review): confirm against ROOTS before relying on it.
	H uint64 = 20033703337
	// ORDER is 2^32, the largest power-of-two root order covered by ROOTS.
	ORDER uint64 = 1 << 32
)

// ROOTS[k] is the value OrderedRoot returns for n = 2^k, k = 0..32.
// The leading entries are 1, PRIME-1 (= -1), 2^48 (a square root of -1) and
// -2^24, and each entry squares to its predecessor. Presumably every ROOTS[k]
// is a primitive 2^k-th root of unity.
var ROOTS = []uint64{
	0x0000000000000001, 0xffffffff00000000, 0x0001000000000000, 0xfffffffeff000001,
	0xefffffff00000001, 0x00003fffffffc000, 0x0000008000000000, 0xf80007ff08000001,
	0xbf79143ce60ca966, 0x1905d02a5c411f4e, 0x9d8f2ad78bfed972, 0x0653b4801da1c8cf,
	0xf2c35199959dfcb6, 0x1544ef2335d17997, 0xe0ee099310bba1e2, 0xf6b2cffe2306baac,
	0x54df9630bf79450e, 0xabd0a6e8aa3d8a0e, 0x81281a7b05f9beac, 0xfbd41c6b8caa3302,
	0x30ba2ecd5e93e76d, 0xf502aef532322654, 0x4b2a18ade67246b5, 0xea9d5a1336fbc98b,
	0x86cdcc31c307e171, 0x4bbaf5976ecfefd8, 0xed41d05b78d6e286, 0x10d78dd8915a171d,
	0x59049500004a4485, 0xdfa8c93ba46d2666, 0x7e9bd009b86a0845, 0x400a7f755588e659,
	0x185629dcda58878c,
}

// Multi-precision constants used by the math/big based reduction helpers.
// They are derived from PRIME so that no parse step (and no ignored parse
// error) is involved.
var (
	// PRIME_128 is PRIME as a *big.Int.
	PRIME_128 = new(big.Int).SetUint64(PRIME)
	// RP is PRIME * 2^64, the exclusive upper bound accepted by montReduction.
	RP = new(big.Int).Lsh(new(big.Int).SetUint64(PRIME), 64)
	// R2 is R^2 mod PRIME for the Montgomery radix R = 2^64, i.e.
	// (2^32 - 1)^2 = 0xFFFFFFFE00000001.
	R2 = new(big.Int).SetUint64(0xFFFFFFFE00000001)
)

// Belt is a field element. Value is expected to lie in [0, PRIME); the
// arithmetic helpers print a warning (but still compute) when it does not.
type Belt struct {
	Value uint64
}

// BaseCheck reports whether x is a canonical field element, i.e. x < PRIME.
func BaseCheck(x uint64) bool {
	return x < PRIME
}

// ZeroBelt returns the additive identity.
func ZeroBelt() Belt {
	return Belt{Value: 0}
}

// OneBelt returns the multiplicative identity.
func OneBelt() Belt {
	return Belt{Value: 1}
}

// IsZero reports whether belt is the zero element.
func (belt *Belt) IsZero() bool {
	return belt.Value == 0
}

// IsOne reports whether belt is the one element.
func (belt *Belt) IsOne() bool {
	return belt.Value == 1
}

// OrderedRoot interprets belt.Value as a root order n and returns the n-th
// root of unity stored in ROOTS (reference formula: bpow(H, ORDER / n)).
//
// n must be an exact power of two no larger than 2^32; otherwise an error is
// returned. Zero is rejected as "not a power of two".
func (belt *Belt) OrderedRoot() (Belt, error) {
	if belt.Value == 0 {
		// bits.Len64(0)-1 would be -1, and shifting by a negative amount panics.
		return Belt{}, fmt.Errorf("ordered_root: not a power of two")
	}
	// Len64 (not Len on uint) so the result does not depend on the platform's
	// word size.
	log2 := bits.Len64(belt.Value) - 1
	if log2 >= len(ROOTS) {
		return Belt{}, fmt.Errorf("ordered_root: out of bounds")
	}
	// Reject anything that is not exactly 2^log2.
	if uint64(1)<<log2 != belt.Value {
		return Belt{}, fmt.Errorf("ordered_root: not a power of two")
	}
	return Belt{Value: ROOTS[log2]}, nil
}

// Inv returns the multiplicative inverse of belt. The inverse of zero is
// reported as zero rather than as an error.
func (belt *Belt) Inv() Belt {
	return Belt{Value: binv(belt.Value)}
}

// Add returns belt + other mod PRIME.
func (belt *Belt) Add(other Belt) Belt {
	return Belt{Value: badd(belt.Value, other.Value)}
}

// Sub returns belt - other mod PRIME.
func (belt *Belt) Sub(other Belt) Belt {
	return Belt{Value: bsub(belt.Value, other.Value)}
}

// Neg returns -belt mod PRIME.
func (belt *Belt) Neg() Belt {
	return Belt{Value: bneg(belt.Value)}
}

// Mul returns belt * other mod PRIME.
func (belt *Belt) Mul(other Belt) Belt {
	return Belt{Value: bmul(belt.Value, other.Value)}
}

// Eq reports whether belt and other hold the same value.
func (belt *Belt) Eq(other Belt) bool {
	return belt.Value == other.Value
}

// Pow returns belt^exp mod PRIME. exp may be any uint64.
func (belt *Belt) Pow(exp uint64) Belt {
	return Belt{Value: bpow(belt.Value, exp)}
}

// Div returns belt * other^-1 mod PRIME. Dividing by zero yields zero, since
// Inv maps zero to zero.
func (belt *Belt) Div(other Belt) Belt {
	return belt.Mul(other.Inv())
}

// binv returns a^(PRIME-2) mod PRIME, the inverse of a for a != 0 (and 0 for
// a == 0). The power is computed in Montgomery form with an addition chain;
// each yK below holds y^(2^K - 1).
func binv(a uint64) uint64 {
	if !BaseCheck(a) {
		fmt.Println("element must be inside the field")
	}
	y := montify(a)
	y2 := montiply(y, montiply(y, y))        // y^(2^2 - 1)
	y3 := montiply(y, montiply(y2, y2))      // y^(2^3 - 1)
	y5 := montiply(y2, montwopow(y3, 2))     // y^(2^5 - 1)
	y10 := montiply(y5, montwopow(y5, 5))    // y^(2^10 - 1)
	y20 := montiply(y10, montwopow(y10, 10)) // y^(2^20 - 1)
	y30 := montiply(y10, montwopow(y20, 10)) // y^(2^30 - 1)
	y31 := montiply(y, montiply(y30, y30))   // y^(2^31 - 1)
	// y^((2^31 - 1) * (2^32 + 1))
	dup := montiply(montwopow(y31, 32), y31)
	// y * dup^2 = y^(2^64 - 2^32 - 1) = y^(PRIME - 2); montReduction then
	// leaves Montgomery form.
	res := new(big.Int).SetUint64(montiply(y, montiply(dup, dup)))
	return montReduction(res)
}

// badd returns a + b mod PRIME.
//
// It computes a - (PRIME - b) so the intermediate never exceeds 64 bits. A
// borrow means a + b < PRIME, and the wrapped difference is then
// a + b + (2^64 - PRIME) = a + b + (2^32 - 1); subtracting 2^32 - 1 recovers
// a + b.
func badd(a, b uint64) uint64 {
	if !BaseCheck(a) || !BaseCheck(b) {
		fmt.Println("element must be inside the field")
	}
	r, borrow := overflowingSub(a, PRIME-b)
	if borrow {
		r -= 0xFFFFFFFF
	}
	return r
}

// bsub returns a - b mod PRIME.
//
// On borrow the wrapped result is a - b + 2^64; subtracting 2^32 - 1 (mod 2^64)
// turns that into a - b + PRIME.
func bsub(a, b uint64) uint64 {
	if !BaseCheck(a) || !BaseCheck(b) {
		fmt.Println("element must be inside the field")
	}
	r, borrow := overflowingSub(a, b)
	if borrow {
		r -= 0xFFFFFFFF
	}
	return r
}

// bneg returns -a mod PRIME (zero stays zero).
func bneg(a uint64) uint64 {
	if !BaseCheck(a) {
		fmt.Println("element must be inside the field")
	}
	if a == 0 {
		return 0
	}
	return PRIME - a
}

// bmul returns a * b mod PRIME. The full 128-bit product is formed with
// math/big and folded back by reduce.
func bmul(a, b uint64) uint64 {
	if !BaseCheck(a) || !BaseCheck(b) {
		fmt.Println("element must be inside the field")
	}
	aBig := new(big.Int).SetUint64(a)
	bBig := new(big.Int).SetUint64(b)
	return reduce(aBig.Mul(aBig, bBig))
}

// bpow returns a^b mod PRIME by square-and-multiply. Only the base is a field
// element; the exponent b is an arbitrary uint64 and is not range-checked.
func bpow(a, b uint64) uint64 {
	if !BaseCheck(a) {
		fmt.Println("element must be inside the field")
	}
	c := uint64(1)
	if b == 0 {
		return c
	}
	// Invariant: the answer equals c * a^b.
	for b > 1 {
		if b&1 == 0 {
			a = bmul(a, a)
			b >>= 1
		} else {
			c = bmul(c, a)
			a = bmul(a, a)
			b = (b - 1) >> 1
		}
	}
	return bmul(a, c)
}

// overflowingSub returns a - b (wrapping mod 2^64) and whether it borrowed.
func overflowingSub(a, b uint64) (uint64, bool) {
	res := a - b
	return res, a < b
}

// overflowingAdd returns a + b (wrapping mod 2^64) and whether it carried.
func overflowingAdd(a, b uint64) (uint64, bool) {
	res := a + b
	// An unsigned sum wrapped iff it is smaller than either operand.
	return res, res < a
}

// montify converts a into Montgomery form, a * 2^64 mod PRIME, by reducing
// a * R2.
func montify(a uint64) uint64 {
	if !BaseCheck(a) {
		fmt.Println("element must be inside the field")
	}
	aBig := new(big.Int).SetUint64(a)
	return montReduction(aBig.Mul(aBig, R2))
}

// montiply multiplies two Montgomery-form values, returning
// a * b * 2^-64 mod PRIME.
func montiply(a, b uint64) uint64 {
	if !BaseCheck(a) || !BaseCheck(b) {
		fmt.Println("element must be inside the field")
	}
	aBig := new(big.Int).SetUint64(a)
	bBig := new(big.Int).SetUint64(b)
	return montReduction(aBig.Mul(aBig, bBig))
}

// montwopow squares the Montgomery-form value a exp times, i.e. raises it to
// the power 2^exp.
func montwopow(a uint64, exp int) uint64 {
	if !BaseCheck(a) {
		fmt.Println("element must be inside the field")
	}
	res := a
	for i := 0; i < exp; i++ {
		res = montiply(res, res)
	}
	return res
}

// montReduction returns a * 2^-64 mod PRIME for 0 <= a < RP (= PRIME * 2^64).
//
// With a split into 32-bit limbs x0, x1 and the high part x2 = a >> 64, it
// computes d = ((x0 + x1) << 32) - x1, minus PRIME if that overflowed 64 bits.
// It then returns x2 - d mod PRIME.
func montReduction(a *big.Int) uint64 {
	if a.Cmp(RP) >= 0 {
		fmt.Println("element must be inside the field")
	}
	mask32 := new(big.Int).SetUint64(0xFFFFFFFF)

	x1 := new(big.Int).Rsh(a, 32)
	x1.And(x1, mask32)
	x2 := new(big.Int).Rsh(a, 64)
	x0 := new(big.Int).And(a, mask32)

	c := new(big.Int).Add(x0, x1)
	c.Lsh(c, 32)
	// f is 1 exactly when c overflowed 64 bits.
	f := new(big.Int).Rsh(c, 64)

	d := new(big.Int).Mul(f, PRIME_128)
	d.Add(d, x1)
	d.Sub(c, d)

	ret := new(big.Int)
	if x2.Cmp(d) >= 0 {
		ret.Sub(x2, d)
	} else {
		ret.Add(x2, PRIME_128)
		ret.Sub(ret, d)
	}
	return ret.Uint64()
}

// reduce returns a mod PRIME for 0 <= a < 2^128 (bmul guarantees this).
//
// Writing a = low + mid*2^64 + high*2^96 with low < 2^64 and mid, high < 2^32,
// and using 2^64 = 2^32 - 1 and 2^96 = -1 (mod PRIME), gives
// a = low - high + mid*(2^32 - 1) (mod PRIME).
func reduce(a *big.Int) uint64 {
	low := new(big.Int).And(a, new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF))
	mid := new(big.Int).Rsh(a, 64)
	mid.And(mid, big.NewInt(0xFFFFFFFF))
	high := new(big.Int).Rsh(a, 96)

	low2, borrow := overflowingSub(low.Uint64(), high.Uint64())
	if borrow {
		// Wrapping "+ PRIME" is "- (2^32 - 1)" mod 2^64, i.e. it removes the
		// 2^64 that the borrow added (2^64 = 2^32 - 1 mod PRIME).
		low2 += PRIME
	}

	// mid * (2^32 - 1), computed without a multiplication; fits in 64 bits.
	product := mid.Uint64() << 32
	product -= product >> 32

	result, carry := overflowingAdd(low2, product)
	if carry {
		// Wrapping "- PRIME" is "+ (2^32 - 1)" mod 2^64: it accounts for the
		// lost 2^64.
		result -= PRIME
	}
	if result >= PRIME {
		result -= PRIME
	}
	return result
}