Skip to content

Commit

Permalink
Add GetStatesInfo method to PubSignals
Browse files Browse the repository at this point in the history
  • Loading branch information
olomix committed Jan 16, 2025
1 parent 48c5013 commit a9cdd15
Show file tree
Hide file tree
Showing 16 changed files with 340 additions and 15 deletions.
13 changes: 13 additions & 0 deletions authV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,16 @@ func (a *AuthV2PubSignals) PubSignalsUnmarshal(data []byte) error {
func (a AuthV2PubSignals) GetObjMap() map[string]interface{} {
return toMap(a)
}

func (a AuthV2PubSignals) GetStatesInfo() (StatesInfo, error) {
if a.UserID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}
if a.GISTRoot == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}
return StatesInfo{
States: []State{},
Gists: []Gist{{ID: *a.UserID, Root: *a.GISTRoot}},
}, nil
}
34 changes: 24 additions & 10 deletions authV2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
it "github.com/iden3/go-circuits/v2/testing"
core "github.com/iden3/go-iden3-core/v2"
"github.com/iden3/go-merkletree-sql/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -161,29 +160,44 @@ func TestAuthV2Circuit_CircuitUnmarshal(t *testing.T) {
// generate mock Data.
intID, b := new(big.Int).SetString("19224224881555258540966250468059781351205177043309252290095510834143232000",
10)
assert.True(t, b)
require.True(t, b)
identifier, err := core.IDFromInt(intID)
assert.Nil(t, err)
require.NoError(t, err)

challenge := big.NewInt(1)

stateInt, b := new(big.Int).SetString(
"18656147546666944484453899241916469544090258810192803949522794490493271005313",
10)
assert.True(t, b)
require.True(t, b)
state, err := merkletree.NewHashFromBigInt(stateInt)
assert.NoError(t, err)
require.NoError(t, err)

out := []string{identifier.BigInt().String(), challenge.String(), state.BigInt().String()}
bytesOut, err := json.Marshal(out)
assert.NoError(t, err)
require.NoError(t, err)

ao := AuthV2PubSignals{}
err = ao.PubSignalsUnmarshal(bytesOut)
assert.NoError(t, err)
assert.Equal(t, challenge, ao.Challenge)
assert.Equal(t, state, ao.GISTRoot)
assert.Equal(t, &identifier, ao.UserID)
require.NoError(t, err)
require.Equal(t, challenge, ao.Challenge)
require.Equal(t, state, ao.GISTRoot)
require.Equal(t, &identifier, ao.UserID)

statesInfo, err := ao.GetStatesInfo()
require.NoError(t, err)
wantStatesInfo := StatesInfo{
States: []State{},
Gists: []Gist{
{
ID: idFromInt("19224224881555258540966250468059781351205177043309252290095510834143232000"),
Root: hashFromInt("18656147546666944484453899241916469544090258810192803949522794490493271005313"),
},
},
}
j, err := json.Marshal(statesInfo)
require.NoError(t, err)
require.Equal(t, wantStatesInfo, statesInfo, string(j))
}

func GetTreeState(t testing.TB, it *it.IdentityTest) TreeState {
Expand Down
10 changes: 8 additions & 2 deletions circuits.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,18 @@ type PubSignals interface {
PubSignalsMapper
}

// PublicStatesInfoProvider interface should be implemented by types that can return
// states info
// PublicStatesInfoProvider interface should be implemented by Inputs types that
// can return public states info
type PublicStatesInfoProvider interface {
GetPublicStatesInfo() (StatesInfo, error)
}

// StatesInfoProvider interface should be implemented by PubSignals types that
// can return states info
type StatesInfoProvider interface {
GetStatesInfo() (StatesInfo, error)
}

// StatesInfo struct. A collection of states and gists
type StatesInfo struct {
States []State `json:"states"`
Expand Down
25 changes: 25 additions & 0 deletions credentialAtomicQueryMTPV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,28 @@ func (ao *AtomicQueryMTPV2PubSignals) PubSignalsUnmarshal(data []byte) error {
func (ao AtomicQueryMTPV2PubSignals) GetObjMap() map[string]interface{} {
return toMap(ao)
}

func (ao AtomicQueryMTPV2PubSignals) GetStatesInfo() (StatesInfo, error) {
if ao.IssuerID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}

if ao.IssuerClaimIdenState == nil || ao.IssuerClaimNonRevState == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}

states := []State{
{
ID: *ao.IssuerID,
State: *ao.IssuerClaimIdenState,
},
}
if *ao.IssuerClaimNonRevState != *ao.IssuerClaimIdenState {
states = append(states, State{
ID: *ao.IssuerID,
State: *ao.IssuerClaimNonRevState,
})
}

return StatesInfo{States: states, Gists: []Gist{}}, nil
}
29 changes: 29 additions & 0 deletions credentialAtomicQueryMTPV2OnChain.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,32 @@ func (ao *AtomicQueryMTPV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) er
func (ao AtomicQueryMTPV2OnChainPubSignals) GetObjMap() map[string]interface{} {
return toMap(ao)
}

func (ao AtomicQueryMTPV2OnChainPubSignals) GetStatesInfo() (StatesInfo, error) {
if ao.IssuerID == nil || ao.UserID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}

if ao.IssuerClaimIdenState == nil || ao.IssuerClaimNonRevState == nil ||
ao.GlobalRoot == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}

states := []State{
{
ID: *ao.IssuerID,
State: *ao.IssuerClaimIdenState,
},
}
if *ao.IssuerClaimNonRevState != *ao.IssuerClaimIdenState {
states = append(states, State{
ID: *ao.IssuerID,
State: *ao.IssuerClaimNonRevState,
})
}

return StatesInfo{
States: states,
Gists: []Gist{{ID: *ao.UserID, Root: *ao.GlobalRoot}},
}, nil
}
20 changes: 20 additions & 0 deletions credentialAtomicQueryMTPV2OnChain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,24 @@ func TestAtomicQueryMTPVOnChain2Outputs_CircuitUnmarshal(t *testing.T) {
require.NoError(t, err)

require.JSONEq(t, string(jsonExp), string(jsonOut))

statesInfo, err := exp.GetStatesInfo()
require.NoError(t, err)
wantStatesInfo := StatesInfo{
States: []State{
{
ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"),
State: hashFromInt("19157496396839393206871475267813888069926627705277243727237933406423274512449"),
},
},
Gists: []Gist{
{
ID: idFromInt("26109404700696283154998654512117952420503675471097392618762221546565140481"),
Root: hashFromInt("11098939821764568131087645431296528907277253709936443029379587475821759259406"),
},
},
}
j, err := json.Marshal(statesInfo)
require.NoError(t, err)
require.Equal(t, wantStatesInfo, statesInfo, string(j))
}
15 changes: 15 additions & 0 deletions credentialAtomicQueryMTPV2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,19 @@ func TestAtomicQueryMTPV2Outputs_CircuitUnmarshal(t *testing.T) {
require.NoError(t, err)

require.JSONEq(t, string(jsonExp), string(jsonOut))

statesInfo, err := exp.GetStatesInfo()
require.NoError(t, err)
wantStatesInfo := StatesInfo{
States: []State{
{
ID: idFromInt("23528770672049181535970744460798517976688641688582489375761566420828291073"),
State: hashFromInt("5687720250943511874245715094520098014548846873346473635855112185560372332782"),
},
},
Gists: []Gist{},
}
j, err := json.Marshal(statesInfo)
require.NoError(t, err)
require.Equal(t, wantStatesInfo, statesInfo, string(j))
}
25 changes: 25 additions & 0 deletions credentialAtomicQuerySigV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,28 @@ func (ao *AtomicQuerySigV2PubSignals) PubSignalsUnmarshal(data []byte) error {
func (ao AtomicQuerySigV2PubSignals) GetObjMap() map[string]interface{} {
return toMap(ao)
}

func (ao AtomicQuerySigV2PubSignals) GetStatesInfo() (StatesInfo, error) {
if ao.IssuerID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}

if ao.IssuerAuthState == nil || ao.IssuerClaimNonRevState == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}

states := []State{
{
ID: *ao.IssuerID,
State: *ao.IssuerAuthState,
},
}
if *ao.IssuerClaimNonRevState != *ao.IssuerAuthState {
states = append(states, State{
ID: *ao.IssuerID,
State: *ao.IssuerClaimNonRevState,
})
}

return StatesInfo{States: states, Gists: []Gist{}}, nil
}
29 changes: 29 additions & 0 deletions credentialAtomicQuerySigV2OnChain.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,32 @@ func (ao *AtomicQuerySigV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) er
func (ao AtomicQuerySigV2OnChainPubSignals) GetObjMap() map[string]interface{} {
return toMap(ao)
}

func (ao AtomicQuerySigV2OnChainPubSignals) GetStatesInfo() (StatesInfo, error) {
if ao.UserID == nil || ao.IssuerID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}

if ao.IssuerAuthState == nil || ao.IssuerClaimNonRevState == nil ||
ao.GlobalRoot == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}

states := []State{
{
ID: *ao.IssuerID,
State: *ao.IssuerAuthState,
},
}
if *ao.IssuerClaimNonRevState != *ao.IssuerAuthState {
states = append(states, State{
ID: *ao.IssuerID,
State: *ao.IssuerClaimNonRevState,
})
}

return StatesInfo{
States: states,
Gists: []Gist{{ID: *ao.UserID, Root: *ao.GlobalRoot}},
}, nil
}
20 changes: 20 additions & 0 deletions credentialAtomicQuerySigV2OnChain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,24 @@ func TestAtomicQuerySigV2OnChainOutputs_CircuitUnmarshal(t *testing.T) {
require.NoError(t, err)

require.JSONEq(t, string(jsonExp), string(jsonOut))

statesInfo, err := exp.GetStatesInfo()
require.NoError(t, err)
wantStatesInfo := StatesInfo{
States: []State{
{
ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"),
State: hashFromInt("20177832565449474772630743317224985532862797657496372535616634430055981993180"),
},
},
Gists: []Gist{
{
ID: idFromInt("26109404700696283154998654512117952420503675471097392618762221546565140481"),
Root: hashFromInt("11098939821764568131087645431296528907277253709936443029379587475821759259406"),
},
},
}
j, err := json.Marshal(statesInfo)
require.NoError(t, err)
require.Equal(t, wantStatesInfo, statesInfo, string(j))
}
38 changes: 37 additions & 1 deletion credentialAtomicQuerySigV2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

it "github.com/iden3/go-circuits/v2/testing"
core "github.com/iden3/go-iden3-core/v2"
"github.com/iden3/go-merkletree-sql/v2"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -237,9 +238,44 @@ func TestAtomicQuerySigOutputs_CircuitUnmarshal(t *testing.T) {
require.NoError(t, err)

require.JSONEq(t, string(jsonExp), string(jsonOut))

statesInfo, err := exp.GetStatesInfo()
require.NoError(t, err)
wantStatesInfo := StatesInfo{
States: []State{
{
ID: idFromInt("21933750065545691586450392143787330185992517860945727248803138245838110721"),
State: hashFromInt("2943483356559152311923412925436024635269538717812859789851139200242297094"),
},
},
Gists: []Gist{},
}
j, err := json.Marshal(statesInfo)
require.NoError(t, err)
require.Equal(t, wantStatesInfo, statesInfo, string(j))
}

func idFromInt(i string) core.ID {
bi, ok := new(big.Int).SetString(i, 10)
if !ok {
panic("can't parse int")
}
id, err := core.IDFromInt(bi)
if err != nil {
panic(err)
}
return id
}

func hashFromInt(i string) merkletree.Hash {
h, err := merkletree.NewHashFromString(i)
if err != nil {
panic(err)
}
return *h
}

func hashFromInt(i *big.Int) *merkletree.Hash {
func hashPtrFromInt(i *big.Int) *merkletree.Hash {
h, err := merkletree.NewHashFromBigInt(i)
if err != nil {
panic(err)
Expand Down
25 changes: 25 additions & 0 deletions credentialAtomicQueryV3.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,28 @@ func (ao *AtomicQueryV3PubSignals) PubSignalsUnmarshal(data []byte) error {
func (ao AtomicQueryV3PubSignals) GetObjMap() map[string]interface{} {
return toMap(ao)
}

func (ao AtomicQueryV3PubSignals) GetStatesInfo() (StatesInfo, error) {
if ao.IssuerID == nil {
return StatesInfo{}, errors.New(ErrorEmptyID)
}

if ao.IssuerState == nil || ao.IssuerClaimNonRevState == nil {
return StatesInfo{}, errors.New(ErrorEmptyStateHash)
}

states := []State{
{
ID: *ao.IssuerID,
State: *ao.IssuerState,
},
}
if *ao.IssuerClaimNonRevState != *ao.IssuerState {
states = append(states, State{
ID: *ao.IssuerID,
State: *ao.IssuerClaimNonRevState,
})
}

return StatesInfo{States: states, Gists: []Gist{}}, nil
}
Loading

0 comments on commit a9cdd15

Please sign in to comment.