diff --git a/cmd/csppsolver/solver.go b/cmd/csppsolver/solver.go index 6bf3362..de18c99 100644 --- a/cmd/csppsolver/solver.go +++ b/cmd/csppsolver/solver.go @@ -36,23 +36,38 @@ type Args struct { type Result struct { Roots []*big.Int + Exponents []int RepeatedRoot *big.Int } +func (*Solver) RootFactors(args Args, res *Result) error { + roots, exps, err := solver.RootFactors(args.A, args.F) + if err != nil { + return err + } + res.Roots = roots + res.Exponents = exps + return nil +} + type repeatedRoot interface { RepeatedRoot() *big.Int } func (*Solver) Roots(args Args, res *Result) error { - roots, err := solver.Roots(args.A, args.F) - if rr, ok := err.(repeatedRoot); ok { - res.RepeatedRoot = rr.RepeatedRoot() - return nil // error set by client package - } + roots, exps, err := solver.RootFactors(args.A, args.F) if err != nil { return err } + for i, exp := range exps { + if exp != 1 { + res.RepeatedRoot = roots[i] + return nil // error set by client package + } + } + res.Roots = roots + res.Exponents = exps return nil } diff --git a/solver/solver.go b/solver/solver.go index 025c793..1068ef2 100644 --- a/solver/solver.go +++ b/solver/solver.go @@ -35,16 +35,12 @@ func factorPoly(fac *C.fmpz_mod_poly_factor_struct, i uintptr) *C.fmpz_mod_poly_ return (*C.fmpz_mod_poly_struct)(unsafe.Pointer(uintptr(unsafe.Pointer(fac.poly)) + i*C.sizeof_fmpz_mod_poly_struct)) } -type repeatedRoot big.Int - -func (r *repeatedRoot) Error() string { return "repeated roots" } -func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) } - -// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F). -// Repeated roots are considered an error for the purposes of unique slot assignment. -func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) { +// RootFactors returns the roots and their number of solutions in the +// factorized polynomial. Repeated roots are an error in the mixing protocol +// but unlike the Roots function are not returned as an error here. +func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) { if len(a) < 2 { - return nil, errors.New("too few coefficients") + return nil, nil, errors.New("too few coefficients") } var mod C.fmpz_t @@ -80,6 +76,7 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) { C.fmpz_mod_poly_factor(&factor[0], &poly[0], &modctx[0]) roots := make([]*big.Int, 0, len(a)-1) + exps := make([]int, 0, len(a)-1) var m C.fmpz_t C.fmpz_init(&m[0]) defer C.fmpz_clear(&m[0]) @@ -93,19 +90,36 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) { b, ok := new(big.Int).SetString(str, base) if !ok { - return nil, errors.New("failed to read fmpz") + return nil, nil, errors.New("failed to read fmpz") } b.Neg(b) b.Mod(b, F) - if factorExp(&factor[0], uintptr(i)) != 1 { - return nil, (*repeatedRoot)(b) - } roots = append(roots, b) + exps = append(exps, int(factorExp(&factor[0], uintptr(i)))) + } + + return roots, exps, nil +} + +type repeatedRoot big.Int + +func (r *repeatedRoot) Error() string { return "repeated roots" } +func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) } + +// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F). +// Repeated roots are considered an error for the purposes of unique slot, +// assignment, and an error with method RepeatedRoot() *big.Int is returned. +func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) { + roots, exps, err := RootFactors(a, F) + if err != nil { + return roots, err } - if len(roots) != len(a)-1 { - return nil, errors.New("too few roots") + for i, exp := range exps { + if exp != 1 { + return nil, (*repeatedRoot)(roots[i]) + } } return roots, nil diff --git a/solver/solver_test.go b/solver/solver_test.go index 80d2650..1734fe4 100644 --- a/solver/solver_test.go +++ b/solver/solver_test.go @@ -429,6 +429,33 @@ func TestRoots(t *testing.T) { } } +func TestRootFactors(t *testing.T) { + for i := range tests { + roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field) + if err != nil { + t.Error(err) + continue + } + for i, exp := range exps { + if exp != 1 { + t.Errorf("repeated root %v at index %v", roots[i], i) + continue + } + } + if len(roots) != len(tests[i].messages) { + t.Error("wrong root count") + continue + } + sortBig(tests[i].messages) + sortBig(roots) + for j := range roots { + if roots[j].Cmp(tests[i].messages[j]) != 0 { + t.Error("recovered wrong message") + } + } + } +} + func BenchmarkRoots(b *testing.B) { for i := range tests { b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) { diff --git a/solverrpc/rpc.go b/solverrpc/rpc.go index 9874396..225f431 100644 --- a/solverrpc/rpc.go +++ b/solverrpc/rpc.go @@ -77,6 +77,31 @@ func StartSolver() error { return onceErr } +// RootFactors returns the roots and their number of solutions in the +// factorized polynomial. Repeated roots are an error in the mixing protocol +// but unlike the Roots function are not returned as an error here. +func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) { + if err := StartSolver(); err != nil { + return nil, nil, err + } + + var args struct { + A []*big.Int + F *big.Int + } + args.A = a + args.F = F + var result struct { + Roots []*big.Int + Exponents []int + } + err := client.Call("Solver.RootFactors", args, &result) + if err != nil { + return nil, nil, err + } + return result.Roots, result.Exponents, nil +} + type repeatedRoot big.Int func (r *repeatedRoot) Error() string { return "repeated roots" } diff --git a/solverrpc/solver_test.go b/solverrpc/solver_test.go index e842fd1..d75e193 100644 --- a/solverrpc/solver_test.go +++ b/solverrpc/solver_test.go @@ -429,6 +429,33 @@ func TestRoots(t *testing.T) { } } +func TestRootFactors(t *testing.T) { + for i := range tests { + roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field) + if err != nil { + t.Error(err) + continue + } + for i, exp := range exps { + if exp != 1 { + t.Errorf("repeated root %v at index %v", roots[i], i) + continue + } + } + if len(roots) != len(tests[i].messages) { + t.Error("wrong root count") + continue + } + sortBig(tests[i].messages) + sortBig(roots) + for j := range roots { + if roots[j].Cmp(tests[i].messages[j]) != 0 { + t.Error("recovered wrong message") + } + } + } +} + func BenchmarkRoots(b *testing.B) { for i := range tests { b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) {