303 lines
6.2 KiB
Go
303 lines
6.2 KiB
Go
package crypto
|
|
|
|
import (
|
|
"fmt"
|
|
"math/big"
|
|
"math/bits"
|
|
)
|
|
|
|
// Base field parameters.
const (
	// PRIME is the field modulus 2^64 - 2^32 + 1 (the "Goldilocks" prime).
	PRIME uint64 = 0xFFFFFFFF00000001

	// PRIME_PRIME is PRIME - 2; a^(PRIME-2) is the inverse of a nonzero a
	// by Fermat's little theorem.
	PRIME_PRIME uint64 = PRIME - 2

	// Note: with R = 2^64, R mod PRIME equals 2^32 - 1 (0xFFFF_FFFF). It is
	// not declared as a constant here.

	// H is the base used by the formula Belt(bpow(H, ORDER / n)) recorded in
	// Belt.OrderedRoot; presumably it generates the subgroup of order ORDER
	// — TODO confirm.
	H uint64 = 20033703337

	// ORDER is 2^32.
	ORDER uint64 = 1 << 32
)
|
|
|
|
// ROOTS holds roots of unity for power-of-two domains: ROOTS[k] is a
// primitive 2^k-th root of unity modulo PRIME, so ROOTS[k]^2 == ROOTS[k-1],
// ROOTS[0] == 1 and ROOTS[1] == PRIME-1 (that is, -1). Belt.OrderedRoot
// indexes this table by log2 of the domain size.
var ROOTS = []uint64{
	0x0000000000000001, // 2^0
	0xffffffff00000000, // 2^1
	0x0001000000000000, // 2^2
	0xfffffffeff000001, // 2^3
	0xefffffff00000001, // 2^4
	0x00003fffffffc000, // 2^5
	0x0000008000000000, // 2^6
	0xf80007ff08000001, // 2^7
	0xbf79143ce60ca966, // 2^8
	0x1905d02a5c411f4e, // 2^9
	0x9d8f2ad78bfed972, // 2^10
	0x0653b4801da1c8cf, // 2^11
	0xf2c35199959dfcb6, // 2^12
	0x1544ef2335d17997, // 2^13
	0xe0ee099310bba1e2, // 2^14
	0xf6b2cffe2306baac, // 2^15
	0x54df9630bf79450e, // 2^16
	0xabd0a6e8aa3d8a0e, // 2^17
	0x81281a7b05f9beac, // 2^18
	0xfbd41c6b8caa3302, // 2^19
	0x30ba2ecd5e93e76d, // 2^20
	0xf502aef532322654, // 2^21
	0x4b2a18ade67246b5, // 2^22
	0xea9d5a1336fbc98b, // 2^23
	0x86cdcc31c307e171, // 2^24
	0x4bbaf5976ecfefd8, // 2^25
	0xed41d05b78d6e286, // 2^26
	0x10d78dd8915a171d, // 2^27
	0x59049500004a4485, // 2^28
	0xdfa8c93ba46d2666, // 2^29
	0x7e9bd009b86a0845, // 2^30
	0x400a7f755588e659, // 2^31
	0x185629dcda58878c, // 2^32
}
|
|
|
|
// Arbitrary-precision copies of field constants, used by the Montgomery
// helpers. They are built from exact integer literals, so no parse result
// needs to be discarded.
var (
	// PRIME_128 is PRIME (2^64 - 2^32 + 1) as a *big.Int.
	PRIME_128 = new(big.Int).SetUint64(0xFFFFFFFF00000001)

	// RP is R*PRIME with R = 2^64; montReduction treats inputs >= RP as out
	// of range.
	RP = new(big.Int).Lsh(new(big.Int).SetUint64(0xFFFFFFFF00000001), 64)

	// R2 is R^2 mod PRIME (= PRIME - 2^32); montify multiplies by it to enter
	// Montgomery form.
	R2 = new(big.Int).SetUint64(0xFFFFFFFE00000001)
)
|
|
|
|
// Belt is a base-field element: an integer modulo PRIME held in Value.
// The arithmetic helpers in this file expect Value < PRIME (see BaseCheck).
// When that does not hold they print a diagnostic but do not fail.
type Belt struct {
	Value uint64
}
|
|
|
|
func BaseCheck(x uint64) bool {
|
|
return x < PRIME
|
|
}
|
|
|
|
func ZeroBelt() Belt {
|
|
return Belt{Value: 0}
|
|
}
|
|
|
|
func OneBelt() Belt {
|
|
return Belt{Value: 1}
|
|
}
|
|
|
|
func (belt *Belt) IsZero() bool {
|
|
return belt.Value == 0
|
|
}
|
|
|
|
func (belt *Belt) IsOne() bool {
|
|
return belt.Value == 1
|
|
}
|
|
|
|
func (belt *Belt) OrderedRoot() (Belt, error) {
|
|
// Belt(bpow(H, ORDER / self.0))
|
|
log2 := bits.Len(uint(belt.Value)) - 1
|
|
if log2 > len(ROOTS) {
|
|
return Belt{}, fmt.Errorf("ordered_root: out of bounds")
|
|
}
|
|
// assert that it was an even power of two
|
|
if (1 << log2) != belt.Value {
|
|
return Belt{}, fmt.Errorf("ordered_root: not a power of two")
|
|
}
|
|
return Belt{Value: ROOTS[log2]}, nil
|
|
}
|
|
|
|
func (belt *Belt) Inv() Belt {
|
|
return Belt{Value: binv(belt.Value)}
|
|
}
|
|
|
|
func (belt *Belt) Add(other Belt) Belt {
|
|
a := belt.Value
|
|
b := other.Value
|
|
return Belt{Value: badd(a, b)}
|
|
}
|
|
|
|
func (belt *Belt) Sub(other Belt) Belt {
|
|
a := belt.Value
|
|
b := other.Value
|
|
return Belt{Value: bsub(a, b)}
|
|
}
|
|
|
|
func (belt *Belt) Neg() Belt {
|
|
return Belt{Value: bneg(belt.Value)}
|
|
}
|
|
|
|
func (belt *Belt) Mul(other Belt) Belt {
|
|
a := belt.Value
|
|
b := other.Value
|
|
return Belt{Value: bmul(a, b)}
|
|
}
|
|
|
|
func (belt *Belt) Eq(other Belt) bool {
|
|
return belt.Value == other.Value
|
|
}
|
|
|
|
func (belt *Belt) Pow(exp uint64) Belt {
|
|
return Belt{Value: bpow(belt.Value, exp)}
|
|
}
|
|
|
|
func (belt *Belt) Div(other Belt) Belt {
|
|
return belt.Mul(other.Inv())
|
|
}
|
|
|
|
func binv(a uint64) uint64 {
|
|
if !BaseCheck(a) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
y := montify(a)
|
|
y2 := montiply(y, montiply(y, y))
|
|
y3 := montiply(y, montiply(y2, y2))
|
|
y5 := montiply(y2, montwopow(y3, 2))
|
|
y10 := montiply(y5, montwopow(y5, 5))
|
|
y20 := montiply(y10, montwopow(y10, 10))
|
|
y30 := montiply(y10, montwopow(y20, 10))
|
|
y31 := montiply(y, montiply(y30, y30))
|
|
dup := montiply(montwopow(y31, 32), y31)
|
|
|
|
res := new(big.Int)
|
|
res.SetUint64(montiply(y, montiply(dup, dup)))
|
|
return montReduction(res)
|
|
}
|
|
|
|
func badd(a, b uint64) uint64 {
|
|
if !BaseCheck(a) || !BaseCheck(b) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
|
|
b = PRIME - b
|
|
r, c := overflowingSub(a, b)
|
|
|
|
adj := uint32(0)
|
|
if c {
|
|
adj = adj - 1
|
|
}
|
|
return r - uint64(adj)
|
|
}
|
|
|
|
func bsub(a, b uint64) uint64 {
|
|
if !BaseCheck(a) || !BaseCheck(b) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
|
|
r, c := overflowingSub(a, b)
|
|
adj := uint32(0)
|
|
if c {
|
|
adj = adj - 1
|
|
}
|
|
|
|
return r - uint64(adj)
|
|
}
|
|
|
|
func bneg(a uint64) uint64 {
|
|
if !BaseCheck(a) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
|
|
if a != 0 {
|
|
return PRIME - a
|
|
} else {
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func bmul(a, b uint64) uint64 {
|
|
if !BaseCheck(a) || !BaseCheck(b) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
aBig := new(big.Int).SetUint64(a)
|
|
bBig := new(big.Int).SetUint64(b)
|
|
return reduce(aBig.Mul(aBig, bBig))
|
|
}
|
|
|
|
func bpow(a, b uint64) uint64 {
|
|
if !BaseCheck(a) || !BaseCheck(b) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
c := uint64(1)
|
|
if b == 0 {
|
|
return c
|
|
}
|
|
|
|
for b > 1 {
|
|
if b&1 == 0 {
|
|
a = bmul(a, a)
|
|
b >>= 1
|
|
} else {
|
|
c = bmul(c, a)
|
|
a = bmul(a, a)
|
|
b = (b - 1) >> 1
|
|
}
|
|
}
|
|
|
|
return bmul(a, c)
|
|
}
|
|
|
|
// overflowingSub returns a - b with wrap-around, together with whether the
// subtraction borrowed (a < b).
func overflowingSub(a, b uint64) (uint64, bool) {
	diff, borrow := bits.Sub64(a, b, 0)
	return diff, borrow != 0
}
|
|
|
|
// overflowingAdd returns a + b with wrap-around, together with whether the
// addition carried out of 64 bits.
func overflowingAdd(a, b uint64) (uint64, bool) {
	sum, carry := bits.Add64(a, b, 0)
	return sum, carry == 1
}
|
|
|
|
func montify(a uint64) uint64 {
|
|
if !BaseCheck(a) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
aBig := new(big.Int).SetUint64(a)
|
|
return montReduction(aBig.Mul(aBig, R2))
|
|
}
|
|
|
|
func montiply(a, b uint64) uint64 {
|
|
if !BaseCheck(a) || !BaseCheck(b) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
aBig := new(big.Int).SetUint64(a)
|
|
bBig := new(big.Int).SetUint64(b)
|
|
return montReduction(aBig.Mul(aBig, bBig))
|
|
}
|
|
|
|
func montwopow(a uint64, exp int) uint64 {
|
|
if !BaseCheck(a) {
|
|
fmt.Println("element must be inside the field")
|
|
}
|
|
|
|
res := a
|
|
for i := 0; i < exp; i++ {
|
|
res = montiply(res, res)
|
|
}
|
|
return res
|
|
}
|
|
// montReduction performs a Montgomery reduction: for 0 <= a < RP (= R*PRIME,
// R = 2^64) it returns a * R^-1 mod PRIME as a canonical value in [0, PRIME).
// montify, montiply and binv use it to enter, multiply within, and leave
// Montgomery form.
//
// Inputs >= RP only trigger a printed diagnostic. The result is then not
// guaranteed to be meaningful.
// NOTE(review): assumes a is non-negative (callers pass products of uint64
// values) — confirm.
//
// Writing a = x2*2^64 + x1*2^32 + x0 with x0, x1 < 2^32, the code computes a
// correction d in [0, PRIME) and returns x2 - d. When x2 < d it returns
// x2 + PRIME - d instead.
func montReduction(a *big.Int) uint64 {
	if a.Cmp(RP) >= 0 {
		fmt.Println("element must be inside the field")
	}

	// x1: bits 32..63 of a.
	x1 := new(big.Int)
	x1.And(x1.Rsh(a, 32), new(big.Int).SetUint64(0xFFFFFFFF))

	// x2: the high word, a >> 64.
	x2 := new(big.Int)
	x2.Rsh(a, 64)

	// x0: bits 0..31 of a.
	x0 := new(big.Int)
	x0.And(a, new(big.Int).SetUint64(0xFFFFFFFF))

	// c = (x0 + x1) << 32. Note x0.Add(x0, x1) overwrites x0; x0 is not used afterwards.
	c := new(big.Int)
	c.Lsh(x0.Add(x0, x1), 32)

	// f is the bit of c above 2^64 (0 or 1).
	f := new(big.Int)
	f.Rsh(c, 64)

	// d = c - (x1 + f*PRIME); d is built in place through d's own receiver.
	d := new(big.Int)
	d.Sub(c, d.Add(x1, d.Mul(f, PRIME_128)))
	if x2.Cmp(d) >= 0 {
		ret := new(big.Int)
		ret.Sub(x2, d)
		return ret.Uint64()
	} else {
		// x2 < d: add PRIME back so the result stays non-negative.
		ret := new(big.Int)
		ret.Sub(ret.Add(x2, PRIME_128), d)
		return ret.Uint64()
	}
}
|
|
|
|
func reduce(a *big.Int) uint64 {
|
|
low := new(big.Int).And(a, new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF))
|
|
|
|
mid := new(big.Int)
|
|
mid.And(mid.Rsh(a, 64), big.NewInt(0xFFFFFFFF))
|
|
|
|
high := new(big.Int)
|
|
high.Rsh(a, 96)
|
|
|
|
low2, carry := overflowingSub(low.Uint64(), high.Uint64())
|
|
if carry {
|
|
low2 = low2 + PRIME
|
|
}
|
|
|
|
product := mid.Uint64() << 32
|
|
product -= product >> 32
|
|
|
|
result, carry := overflowingAdd(low2, product)
|
|
if carry {
|
|
result = result - PRIME
|
|
}
|
|
|
|
if result >= PRIME {
|
|
result -= PRIME
|
|
}
|
|
return result
|
|
}
|