diff --git a/authV2.go b/authV2.go index d36bc7e..5007aad 100644 --- a/authV2.go +++ b/authV2.go @@ -135,23 +135,35 @@ func (a AuthV2Inputs) InputsMarshal() ([]byte, error) { return json.Marshal(s) } -// AuthV2PubSignals auth.circom public signals -type AuthV2PubSignals struct { - UserID *core.ID `json:"userID"` - Challenge *big.Int `json:"challenge"` - GISTRoot *merkletree.Hash `json:"GISTRoot"` -} +// GetPublicStatesInfo returns states and gists information, +// implements PublicStatesInfoProvider interface +func (a AuthV2Inputs) GetPublicStatesInfo() (StatesInfo, error) { -func (ao *AuthV2PubSignals) GetStatesInfo() StatesInfo { + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + userID, err := core.ProfileID(*a.GenesisID, a.ProfileNonce) + if err != nil { + return StatesInfo{}, err + } return StatesInfo{ States: []State{}, Gists: []Gist{ { - ID: ao.UserID, - Root: ao.GISTRoot, + ID: userID, + Root: *a.GISTProof.Root, }, }, - } + }, nil +} + + +// AuthV2PubSignals auth.circom public signals +type AuthV2PubSignals struct { + UserID *core.ID `json:"userID"` + Challenge *big.Int `json:"challenge"` + GISTRoot *merkletree.Hash `json:"GISTRoot"` } // PubSignalsUnmarshal unmarshal auth.circom public inputs to AuthPubSignals @@ -186,3 +198,16 @@ func (a *AuthV2PubSignals) PubSignalsUnmarshal(data []byte) error { func (a AuthV2PubSignals) GetObjMap() map[string]interface{} { return toMap(a) } + +func (a AuthV2PubSignals) GetStatesInfo() (StatesInfo, error) { + if a.UserID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + if a.GISTRoot == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + return StatesInfo{ + States: []State{}, + Gists: []Gist{{ID: *a.UserID, Root: *a.GISTRoot}}, + }, nil +} diff --git a/authV2_json.go b/authV2_json.go index 6e03a0b..458236b 100644 --- a/authV2_json.go +++ b/authV2_json.go @@ -34,6 +34,13 @@ func (j *jsonInt) MarshalJSON() ([]byte, error) { return json.Marshal((*big.Int)(j).String()) } +func (j *jsonInt) BigInt() *big.Int { + if j == nil { + return nil + } + return (*big.Int)(j) +} + type jsonSignature babyjub.Signature func (s *jsonSignature) UnmarshalJSON(bytes []byte) error { diff --git a/authV2_test.go b/authV2_test.go index 3683494..3d17754 100644 --- a/authV2_test.go +++ b/authV2_test.go @@ -9,12 +9,10 @@ import ( it "github.com/iden3/go-circuits/v2/testing" core "github.com/iden3/go-iden3-core/v2" "github.com/iden3/go-merkletree-sql/v2" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestAuthV2Inputs_InputsMarshal(t *testing.T) { - +func authV2Inputs(t testing.TB) AuthV2Inputs { ctx := context.Background() challenge := big.NewInt(10) @@ -42,7 +40,7 @@ func TestAuthV2Inputs_InputsMarshal(t *testing.T) { signature, err := user.SignBBJJ(challenge.Bytes()) require.NoError(t, err) - inputs := AuthV2Inputs{ + return AuthV2Inputs{ GenesisID: &user.ID, ProfileNonce: nonce, AuthClaim: user.AuthClaim, @@ -56,15 +54,39 @@ func TestAuthV2Inputs_InputsMarshal(t *testing.T) { Signature: signature, Challenge: challenge, } +} +func TestAuthV2Inputs_InputsMarshal(t *testing.T) { + inputs := authV2Inputs(t) circuitInputJSON, err := inputs.InputsMarshal() - assert.Nil(t, err) + require.NoError(t, err) //t.Log(string(circuitInputJSON)) exp := it.TestData(t, "authV2_inputs", string(circuitInputJSON), *generate) require.JSONEq(t, exp, string(circuitInputJSON)) } +func TestAuthV2Inputs_GetPublicStatesInfo(t *testing.T) { + inputs := authV2Inputs(t) + statesInfo, err := inputs.GetPublicStatesInfo() + require.NoError(t, err) + + statesInfoJsonBytes, err := json.Marshal(statesInfo) + require.NoError(t, err) + + want := `{ + "states":[], + "gists":[ + { + "id":"26109404700696283154998654512117952420503675471097392618762221546565140481", + "root":"11098939821764568131087645431296528907277253709936443029379587475821759259406" + } + ] +}` + + require.JSONEq(t, want, string(statesInfoJsonBytes)) +} + func TestAuthV2Inputs_InputsMarshal_fromJson(t *testing.T) { t.Skip("skipping TODO: finish test") auth2_json := `{ @@ -138,32 +160,47 @@ func TestAuthV2Circuit_CircuitUnmarshal(t *testing.T) { // generate mock Data. intID, b := new(big.Int).SetString("19224224881555258540966250468059781351205177043309252290095510834143232000", 10) - assert.True(t, b) + require.True(t, b) identifier, err := core.IDFromInt(intID) - assert.Nil(t, err) + require.NoError(t, err) challenge := big.NewInt(1) stateInt, b := new(big.Int).SetString( "18656147546666944484453899241916469544090258810192803949522794490493271005313", 10) - assert.True(t, b) + require.True(t, b) state, err := merkletree.NewHashFromBigInt(stateInt) - assert.NoError(t, err) + require.NoError(t, err) out := []string{identifier.BigInt().String(), challenge.String(), state.BigInt().String()} bytesOut, err := json.Marshal(out) - assert.NoError(t, err) + require.NoError(t, err) ao := AuthV2PubSignals{} err = ao.PubSignalsUnmarshal(bytesOut) - assert.NoError(t, err) - assert.Equal(t, challenge, ao.Challenge) - assert.Equal(t, state, ao.GISTRoot) - assert.Equal(t, &identifier, ao.UserID) + require.NoError(t, err) + require.Equal(t, challenge, ao.Challenge) + require.Equal(t, state, ao.GISTRoot) + require.Equal(t, &identifier, ao.UserID) + + statesInfo, err := ao.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{}, + Gists: []Gist{ + { + ID: idFromInt("19224224881555258540966250468059781351205177043309252290095510834143232000"), + Root: hashFromInt("18656147546666944484453899241916469544090258810192803949522794490493271005313"), + }, + }, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } -func GetTreeState(t *testing.T, it *it.IdentityTest) TreeState { +func GetTreeState(t testing.TB, it *it.IdentityTest) TreeState { return TreeState{ State: it.State(t), ClaimsRoot: it.Clt.Root(), diff --git a/circuits.go b/circuits.go index 488d447..15a12bd 100644 --- a/circuits.go +++ b/circuits.go @@ -1,6 +1,8 @@ package circuits import ( + "bytes" + "encoding/json" "reflect" "sync" @@ -198,27 +200,110 @@ type PubSignals interface { PubSignalsMapper } -// StateInfoPubSignals interface implemented by types that can return states info -type StateInfoPubSignals interface { - GetStatesInfo() StatesInfo +// PublicStatesInfoProvider interface should be implemented by Inputs types that +// can return public states info +type PublicStatesInfoProvider interface { + GetPublicStatesInfo() (StatesInfo, error) +} + +// StatesInfoProvider interface should be implemented by PubSignals types that +// can return states info +type StatesInfoProvider interface { + GetStatesInfo() (StatesInfo, error) } // StatesInfo struct. A collection of states and gists type StatesInfo struct { - States []State - Gists []Gist + States []State `json:"states"` + Gists []Gist `json:"gists"` } // State information type State struct { - ID *core.ID - State *merkletree.Hash + ID core.ID `json:"id"` + State merkletree.Hash `json:"state"` +} + +func (s *State) UnmarshalJSON(i []byte) error { + var j struct { + ID *jsonInt `json:"id"` + State *jsonInt `json:"state"` + } + err := json.Unmarshal(i, &j) + if err != nil { + return err + } + + if j.ID == nil { + return errors.New("id is nil") + } + s.ID, err = core.IDFromInt(j.ID.BigInt()) + if err != nil { + return err + } + + h, err := merkletree.NewHashFromBigInt(j.State.BigInt()) + if err != nil { + return err + } + s.State = *h + + return nil +} + +func (s State) MarshalJSON() ([]byte, error) { + var b bytes.Buffer + b.Grow(256) // 20 + 2*78 + padding ≈ 256 + b.Write([]byte(`{"id":"`)) + b.Write([]byte(s.ID.BigInt().String())) + b.Write([]byte(`","state":"`)) + b.Write([]byte(s.State.BigInt().String())) + b.Write([]byte(`"}`)) + return b.Bytes(), nil } // Gist information type Gist struct { - ID *core.ID - Root *merkletree.Hash + ID core.ID `json:"id"` + Root merkletree.Hash `json:"root"` +} + +func (g *Gist) UnmarshalJSON(i []byte) error { + var j struct { + ID *jsonInt `json:"id"` + Root *jsonInt `json:"root"` + } + err := json.Unmarshal(i, &j) + if err != nil { + return err + } + + if j.ID == nil { + return errors.New("id is nil") + } + g.ID, err = core.IDFromInt(j.ID.BigInt()) + if err != nil { + return err + } + + h, err := merkletree.NewHashFromBigInt(j.Root.BigInt()) + if err != nil { + return err + } + g.Root = *h + + return nil +} + +func (g Gist) MarshalJSON() ([]byte, error) { + var b bytes.Buffer + b.Grow(256) // 20 + 2*78 + padding ≈ 256 + b.Write([]byte(`{"id":"`)) + b.Write([]byte(g.ID.BigInt().String())) + b.Write([]byte(`","root":"`)) + b.Write([]byte(g.Root.BigInt().String())) + b.Write([]byte(`"}`)) + return b.Bytes(), nil } // KeyLoader interface, if key should be fetched from file system, CDN, IPFS etc, diff --git a/circuits_test.go b/circuits_test.go index 987ee3f..8530a06 100644 --- a/circuits_test.go +++ b/circuits_test.go @@ -8,6 +8,7 @@ import ( core "github.com/iden3/go-iden3-core/v2" "github.com/iden3/go-merkletree-sql/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshalCircuitOutput(t *testing.T) { @@ -39,3 +40,55 @@ func TestUnmarshalCircuitOutput_Err(t *testing.T) { assert.Equal(t, err, ErrorCircuitIDNotFound) } + +func TestGistJsonMarshallers(t *testing.T) { + var in Gist + var err error + in.ID, err = core.IDFromString("tQomzpDTB6x4EJUaiwk153FVi96jeNfP9WjKp9xys") + require.NoError(t, err) + + h, err := merkletree.NewHashFromString("11098939821764568131087645431296528907277253709936443029379587475821759259406") + require.NoError(t, err) + in.Root = *h + + wantJson := `{ + "id": "26109404700696283154998654512117952420503675471097392618762221546565140481", + "root": "11098939821764568131087645431296528907277253709936443029379587475821759259406" +}` + + inJsonBytes, err := json.Marshal(in) + require.NoError(t, err) + + require.JSONEq(t, wantJson, string(inJsonBytes)) + + var out Gist + err = json.Unmarshal(inJsonBytes, &out) + require.NoError(t, err) + require.Equal(t, in, out) +} + +func TestStateJsonMarshallers(t *testing.T) { + var in State + var err error + in.ID, err = core.IDFromString("tQomzpDTB6x4EJUaiwk153FVi96jeNfP9WjKp9xys") + require.NoError(t, err) + + h, err := merkletree.NewHashFromString("11098939821764568131087645431296528907277253709936443029379587475821759259406") + require.NoError(t, err) + in.State = *h + + wantJson := `{ + "id": "26109404700696283154998654512117952420503675471097392618762221546565140481", + "state": "11098939821764568131087645431296528907277253709936443029379587475821759259406" +}` + + inJsonBytes, err := json.Marshal(in) + require.NoError(t, err) + + require.JSONEq(t, wantJson, string(inJsonBytes)) + + var out State + err = json.Unmarshal(inJsonBytes, &out) + require.NoError(t, err) + require.Equal(t, in, out) +} diff --git a/credentialAtomicQueryMTPV2.go b/credentialAtomicQueryMTPV2.go index c8e636e..bdd4074 100644 --- a/credentialAtomicQueryMTPV2.go +++ b/credentialAtomicQueryMTPV2.go @@ -164,6 +164,46 @@ func (a AtomicQueryMTPV2Inputs) InputsMarshal() ([]byte, error) { return json.Marshal(s) } +func (a AtomicQueryMTPV2Inputs) GetPublicStatesInfo() (StatesInfo, error) { + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + if a.Claim.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyClaim) + } + issuerID := *a.Claim.IssuerID + + if a.Claim.IncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: issuerID, + State: *a.Claim.IncProof.TreeState.State, + }, + }, + Gists: []Gist{}, + } + + nonRevProofState := *a.Claim.NonRevProof.TreeState.State + + if statesInfo.States[0].State != nonRevProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: issuerID, + State: nonRevProofState, + }) + } + + return statesInfo, nil +} + // AtomicQueryMTPV2PubSignals public signals type AtomicQueryMTPV2PubSignals struct { BaseConfig @@ -319,3 +359,28 @@ func (ao *AtomicQueryMTPV2PubSignals) PubSignalsUnmarshal(data []byte) error { func (ao AtomicQueryMTPV2PubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQueryMTPV2PubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerClaimIdenState == nil || ao.IssuerClaimNonRevState == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerClaimIdenState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerClaimIdenState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{States: states, Gists: []Gist{}}, nil +} diff --git a/credentialAtomicQueryMTPV2OnChain.go b/credentialAtomicQueryMTPV2OnChain.go index 5e74c8a..6fe2a68 100644 --- a/credentialAtomicQueryMTPV2OnChain.go +++ b/credentialAtomicQueryMTPV2OnChain.go @@ -251,6 +251,60 @@ func (a AtomicQueryMTPV2OnChainInputs) InputsMarshal() ([]byte, error) { return json.Marshal(s) } +func (a AtomicQueryMTPV2OnChainInputs) GetPublicStatesInfo() (StatesInfo, error) { + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + if a.Claim.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyClaim) + } + issuerID := *a.Claim.IssuerID + + if a.Claim.IncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + userID, err := core.ProfileID(*a.ID, a.ProfileNonce) + if err != nil { + return StatesInfo{}, err + } + + if a.GISTProof.Root == nil { + return StatesInfo{}, errors.New(ErrorEmptyGISTProof) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: issuerID, + State: *a.Claim.IncProof.TreeState.State, + }, + }, + Gists: []Gist{ + { + ID: userID, + Root: *a.GISTProof.Root, + }, + }, + } + + nonRevProofState := *a.Claim.NonRevProof.TreeState.State + + if statesInfo.States[0].State != nonRevProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: issuerID, + State: nonRevProofState, + }) + } + + return statesInfo, nil +} + // AtomicQueryMTPPubSignals public signals type AtomicQueryMTPV2OnChainPubSignals struct { BaseConfig @@ -267,27 +321,6 @@ type AtomicQueryMTPV2OnChainPubSignals struct { GlobalRoot *merkletree.Hash `json:"gistRoot"` } -func (ao *AtomicQueryMTPV2OnChainPubSignals) GetStatesInfo() StatesInfo { - return StatesInfo{ - States: []State{ - { - ID: ao.IssuerID, - State: ao.IssuerClaimIdenState, - }, - { - ID: ao.IssuerID, - State: ao.IssuerClaimNonRevState, - }, - }, - Gists: []Gist{ - { - ID: ao.UserID, - Root: ao.GlobalRoot, - }, - }, - } -} - // PubSignalsUnmarshal unmarshal credentialAtomicQueryMTPV2OnChain.circom public signals array to AtomicQueryMTPPubSignals func (ao *AtomicQueryMTPV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) error { @@ -387,3 +420,32 @@ func (ao *AtomicQueryMTPV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) er func (ao AtomicQueryMTPV2OnChainPubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQueryMTPV2OnChainPubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.IssuerID == nil || ao.UserID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerClaimIdenState == nil || ao.IssuerClaimNonRevState == nil || + ao.GlobalRoot == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerClaimIdenState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerClaimIdenState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{ + States: states, + Gists: []Gist{{ID: *ao.UserID, Root: *ao.GlobalRoot}}, + }, nil +} diff --git a/credentialAtomicQueryMTPV2OnChain_test.go b/credentialAtomicQueryMTPV2OnChain_test.go index 68a74a7..3df4668 100644 --- a/credentialAtomicQueryMTPV2OnChain_test.go +++ b/credentialAtomicQueryMTPV2OnChain_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAttrQueryMTPV2OnChain_PrepareInputs(t *testing.T) { +func queryMTPV2OnChainInputs(t testing.TB) AtomicQueryMTPV2OnChainInputs { challenge := big.NewInt(10) // generate identity @@ -50,7 +50,7 @@ func TestAttrQueryMTPV2OnChain_PrepareInputs(t *testing.T) { issuerClaimNonRevMtp, _ := issuer.ClaimRevMTPRaw(t, claim) - in := AtomicQueryMTPV2OnChainInputs{ + return AtomicQueryMTPV2OnChainInputs{ RequestID: big.NewInt(23), ID: &user.ID, ProfileNonce: nonce, @@ -95,7 +95,10 @@ func TestAttrQueryMTPV2OnChain_PrepareInputs(t *testing.T) { Signature: signature, Challenge: challenge, } +} +func TestAttrQueryMTPV2OnChain_PrepareInputs(t *testing.T) { + in := queryMTPV2OnChainInputs(t) bytesInputs, err := in.InputsMarshal() require.Nil(t, err) exp := it.TestData(t, "mtpV2OnChain_inputs", string(bytesInputs), *generate) @@ -104,6 +107,31 @@ func TestAttrQueryMTPV2OnChain_PrepareInputs(t *testing.T) { } +func TestAttrQueryMTPV2OnChain_GetPublicStatesInfo(t *testing.T) { + in := queryMTPV2OnChainInputs(t) + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + + bs, err := json.Marshal(statesInfo) + require.NoError(t, err) + + wantStatesInfo := `{ + "states": [ + { + "id": "27918766665310231445021466320959318414450284884582375163563581940319453185", + "state": "19157496396839393206871475267813888069926627705277243727237933406423274512449" + } + ], + "gists": [ + { + "id": "26109404700696283154998654512117952420503675471097392618762221546565140481", + "root": "11098939821764568131087645431296528907277253709936443029379587475821759259406" + } + ] +}` + require.JSONEq(t, wantStatesInfo, string(bs)) +} + func TestAtomicQueryMTPVOnChain2Outputs_CircuitUnmarshal(t *testing.T) { out := new(AtomicQueryMTPV2OnChainPubSignals) err := out.PubSignalsUnmarshal([]byte(`[ @@ -162,4 +190,24 @@ func TestAtomicQueryMTPVOnChain2Outputs_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"), + State: hashFromInt("19157496396839393206871475267813888069926627705277243727237933406423274512449"), + }, + }, + Gists: []Gist{ + { + ID: idFromInt("26109404700696283154998654512117952420503675471097392618762221546565140481"), + Root: hashFromInt("11098939821764568131087645431296528907277253709936443029379587475821759259406"), + }, + }, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } diff --git a/credentialAtomicQueryMTPV2_test.go b/credentialAtomicQueryMTPV2_test.go index 1c1eb72..3242616 100644 --- a/credentialAtomicQueryMTPV2_test.go +++ b/credentialAtomicQueryMTPV2_test.go @@ -9,8 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAttrQueryMTPV2_PrepareInputs(t *testing.T) { - +func queryMTPV2Inputs(t testing.TB) AtomicQueryMTPV2Inputs { user := it.NewIdentity(t, userPK) issuer := it.NewIdentity(t, issuerPK) @@ -27,7 +26,7 @@ func TestAttrQueryMTPV2_PrepareInputs(t *testing.T) { issuerClaimNonRevMtp, _ := issuer.ClaimRevMTPRaw(t, claim) - in := AtomicQueryMTPV2Inputs{ + return AtomicQueryMTPV2Inputs{ RequestID: big.NewInt(23), ID: &user.ID, ProfileNonce: nonce, @@ -62,14 +61,36 @@ func TestAttrQueryMTPV2_PrepareInputs(t *testing.T) { }, CurrentTimeStamp: timestamp, } +} +func TestAttrQueryMTPV2_PrepareInputs(t *testing.T) { + in := queryMTPV2Inputs(t) bytesInputs, err := in.InputsMarshal() require.Nil(t, err) exp := it.TestData(t, "mtpV2_inputs", string(bytesInputs), *generate) t.Log(string(bytesInputs)) require.JSONEq(t, exp, string(bytesInputs)) +} + +func TestAttrQueryMTPV2_GetPublicStatesInfo(t *testing.T) { + in := queryMTPV2Inputs(t) + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + bs, err := json.Marshal(statesInfo) + require.NoError(t, err) + + wantStatesInfo := `{ + "states": [ + { + "id": "27918766665310231445021466320959318414450284884582375163563581940319453185", + "state": "19157496396839393206871475267813888069926627705277243727237933406423274512449" + } + ], + "gists": [] +}` + require.JSONEq(t, wantStatesInfo, string(bs)) } func TestAtomicQueryMTPV2Outputs_CircuitUnmarshal(t *testing.T) { @@ -184,4 +205,19 @@ func TestAtomicQueryMTPV2Outputs_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("23528770672049181535970744460798517976688641688582489375761566420828291073"), + State: hashFromInt("5687720250943511874245715094520098014548846873346473635855112185560372332782"), + }, + }, + Gists: []Gist{}, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } diff --git a/credentialAtomicQuerySigV2.go b/credentialAtomicQuerySigV2.go index 50cad6b..6da5ea6 100644 --- a/credentialAtomicQuerySigV2.go +++ b/credentialAtomicQuerySigV2.go @@ -206,6 +206,43 @@ func (a AtomicQuerySigV2Inputs) InputsMarshal() ([]byte, error) { return json.Marshal(s) } +func (a AtomicQuerySigV2Inputs) GetPublicStatesInfo() (StatesInfo, error) { + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + issuerID := a.Claim.IssuerID + var issuerState merkletree.Hash + if a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: *issuerID, + State: issuerState, + }, + }, + Gists: []Gist{}, + } + + nonRevProofState := *a.Claim.NonRevProof.TreeState.State + if issuerState != nonRevProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: *issuerID, + State: nonRevProofState, + }) + } + + return statesInfo, nil +} + // AtomicQuerySigV2PubSignals public inputs type AtomicQuerySigV2PubSignals struct { BaseConfig @@ -360,3 +397,28 @@ func (ao *AtomicQuerySigV2PubSignals) PubSignalsUnmarshal(data []byte) error { func (ao AtomicQuerySigV2PubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQuerySigV2PubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerAuthState == nil || ao.IssuerClaimNonRevState == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerAuthState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerAuthState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{States: states, Gists: []Gist{}}, nil +} diff --git a/credentialAtomicQuerySigV2OnChain.go b/credentialAtomicQuerySigV2OnChain.go index a8e53af..19d2892 100644 --- a/credentialAtomicQuerySigV2OnChain.go +++ b/credentialAtomicQuerySigV2OnChain.go @@ -294,6 +294,58 @@ func (a AtomicQuerySigV2OnChainInputs) InputsMarshal() ([]byte, error) { return json.Marshal(s) } +func (a AtomicQuerySigV2OnChainInputs) GetPublicStatesInfo() (StatesInfo, error) { + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + issuerID := a.Claim.IssuerID + var issuerState merkletree.Hash + if a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + userID, err := core.ProfileID(*a.ID, a.ProfileNonce) + if err != nil { + return StatesInfo{}, err + } + + if a.GISTProof.Root == nil { + return StatesInfo{}, errors.New(ErrorEmptyGISTProof) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: *issuerID, + State: issuerState, + }, + }, + Gists: []Gist{ + { + ID: userID, + Root: *a.GISTProof.Root, + }, + }, + } + + nonRevProofState := *a.Claim.NonRevProof.TreeState.State + if issuerState != nonRevProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: *issuerID, + State: nonRevProofState, + }) + } + + return statesInfo, nil +} + + // AtomicQuerySigV2OnChainPubSignals public inputs type AtomicQuerySigV2OnChainPubSignals struct { BaseConfig @@ -310,27 +362,6 @@ type AtomicQuerySigV2OnChainPubSignals struct { GlobalRoot *merkletree.Hash `json:"gistRoot"` } -func (ao *AtomicQuerySigV2OnChainPubSignals) GetStatesInfo() StatesInfo { - return StatesInfo{ - States: []State{ - { - ID: ao.IssuerID, - State: ao.IssuerAuthState, - }, - { - ID: ao.IssuerID, - State: ao.IssuerClaimNonRevState, - }, - }, - Gists: []Gist{ - { - ID: ao.UserID, - Root: ao.GlobalRoot, - }, - }, - } -} - // PubSignalsUnmarshal unmarshal credentialAtomicQuerySig.circom public signals func (ao *AtomicQuerySigV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) error { // expected order: @@ -431,3 +462,32 @@ func (ao *AtomicQuerySigV2OnChainPubSignals) PubSignalsUnmarshal(data []byte) er func (ao AtomicQuerySigV2OnChainPubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQuerySigV2OnChainPubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.UserID == nil || ao.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerAuthState == nil || ao.IssuerClaimNonRevState == nil || + ao.GlobalRoot == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerAuthState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerAuthState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{ + States: states, + Gists: []Gist{{ID: *ao.UserID, Root: *ao.GlobalRoot}}, + }, nil +} diff --git a/credentialAtomicQuerySigV2OnChain_test.go b/credentialAtomicQuerySigV2OnChain_test.go index f333d76..781d27f 100644 --- a/credentialAtomicQuerySigV2OnChain_test.go +++ b/credentialAtomicQuerySigV2OnChain_test.go @@ -11,8 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAttrQuerySigV2OnChain_PrepareInputs(t *testing.T) { - +func querySigV2OnChainInputs(t testing.TB) AtomicQuerySigV2OnChainInputs { user := it.NewIdentity(t, userPK) issuer := it.NewIdentity(t, issuerPK) @@ -49,7 +48,7 @@ func TestAttrQuerySigV2OnChain_PrepareInputs(t *testing.T) { signature, err := user.SignBBJJ(challenge.Bytes()) require.NoError(t, err) - in := AtomicQuerySigV2OnChainInputs{ + return AtomicQuerySigV2OnChainInputs{ RequestID: big.NewInt(23), ID: &user.ID, ProfileNonce: profileNonce, @@ -107,13 +106,40 @@ func TestAttrQuerySigV2OnChain_PrepareInputs(t *testing.T) { Signature: signature, Challenge: challenge, } +} +func TestAttrQuerySigV2OnChain_PrepareInputs(t *testing.T) { + in := querySigV2OnChainInputs(t) bytesInputs, err := in.InputsMarshal() require.Nil(t, err) exp := it.TestData(t, "sigV2OnChain_inputs", string(bytesInputs), *generate) require.JSONEq(t, exp, string(bytesInputs)) +} + +func TestAttrQuerySigV2OnChain_GetPublicStatesInfo(t *testing.T) { + in := querySigV2OnChainInputs(t) + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + + bs, err := json.Marshal(statesInfo) + require.NoError(t, err) + wantStatesInfo := `{ + "states": [ + { + "id": "27918766665310231445021466320959318414450284884582375163563581940319453185", + "state": "20177832565449474772630743317224985532862797657496372535616634430055981993180" + } + ], + "gists": [ + { + "id": "26109404700696283154998654512117952420503675471097392618762221546565140481", + "root": "11098939821764568131087645431296528907277253709936443029379587475821759259406" + } + ] +}` + require.JSONEq(t, wantStatesInfo, string(bs)) } func TestAtomicQuerySigV2OnChainOutputs_CircuitUnmarshal(t *testing.T) { @@ -172,4 +198,24 @@ func TestAtomicQuerySigV2OnChainOutputs_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"), + State: hashFromInt("20177832565449474772630743317224985532862797657496372535616634430055981993180"), + }, + }, + Gists: []Gist{ + { + ID: idFromInt("26109404700696283154998654512117952420503675471097392618762221546565140481"), + Root: hashFromInt("11098939821764568131087645431296528907277253709936443029379587475821759259406"), + }, + }, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } diff --git a/credentialAtomicQuerySigV2_test.go b/credentialAtomicQuerySigV2_test.go index 744ecd2..3b180c4 100644 --- a/credentialAtomicQuerySigV2_test.go +++ b/credentialAtomicQuerySigV2_test.go @@ -8,6 +8,7 @@ import ( "testing" it "github.com/iden3/go-circuits/v2/testing" + core "github.com/iden3/go-iden3-core/v2" "github.com/iden3/go-merkletree-sql/v2" "github.com/stretchr/testify/require" ) @@ -27,10 +28,8 @@ const ( timestamp = 1642074362 ) -func TestAttrQuerySigV2_PrepareInputs(t *testing.T) { - +func querySigV2Inputs(t testing.TB) AtomicQuerySigV2Inputs { user := it.NewIdentity(t, userPK) - issuer := it.NewIdentity(t, issuerPK) subjectID := user.ID @@ -48,7 +47,7 @@ func TestAttrQuerySigV2_PrepareInputs(t *testing.T) { issuerAuthClaimNonRevMtp, _ := issuer.ClaimRevMTPRaw(t, issuer.AuthClaim) issuerAuthClaimMtp, _ := issuer.ClaimMTPRaw(t, issuer.AuthClaim) - in := AtomicQuerySigV2Inputs{ + return AtomicQuerySigV2Inputs{ RequestID: big.NewInt(23), ID: &user.ID, ProfileNonce: profileNonce, @@ -96,7 +95,10 @@ func TestAttrQuerySigV2_PrepareInputs(t *testing.T) { }, CurrentTimeStamp: timestamp, } +} +func TestAttrQuerySigV2_PrepareInputs(t *testing.T) { + in := querySigV2Inputs(t) bytesInputs, err := in.InputsMarshal() require.Nil(t, err) @@ -105,6 +107,26 @@ func TestAttrQuerySigV2_PrepareInputs(t *testing.T) { } +func TestAttrQuerySigV2_GetPublicStatesInfo(t *testing.T) { + in := querySigV2Inputs(t) + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + + bs, err := json.Marshal(statesInfo) + require.NoError(t, err) + + wantStatesInfo := `{ + "states": [ + { + "id": "27918766665310231445021466320959318414450284884582375163563581940319453185", + "state": "20177832565449474772630743317224985532862797657496372535616634430055981993180" + } + ], + "gists": [] +}` + require.JSONEq(t, wantStatesInfo, string(bs)) +} + func TestAtomicQuerySigOutputs_CircuitUnmarshal(t *testing.T) { out := new(AtomicQuerySigV2PubSignals) err := out.PubSignalsUnmarshal([]byte( @@ -216,9 +238,44 @@ func TestAtomicQuerySigOutputs_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("21933750065545691586450392143787330185992517860945727248803138245838110721"), + State: hashFromInt("2943483356559152311923412925436024635269538717812859789851139200242297094"), + }, + }, + Gists: []Gist{}, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) +} + +func idFromInt(i string) core.ID { + bi, ok := new(big.Int).SetString(i, 10) + if !ok { + panic("can't parse int") + } + id, err := core.IDFromInt(bi) + if err != nil { + panic(err) + } + return id +} + +func hashFromInt(i string) merkletree.Hash { + h, err := merkletree.NewHashFromString(i) + if err != nil { + panic(err) + } + return *h } -func hashFromInt(i *big.Int) *merkletree.Hash { +func hashPtrFromInt(i *big.Int) *merkletree.Hash { h, err := merkletree.NewHashFromBigInt(i) if err != nil { panic(err) diff --git a/credentialAtomicQueryV3.go b/credentialAtomicQueryV3.go index d309a68..202e36d 100644 --- a/credentialAtomicQueryV3.go +++ b/credentialAtomicQueryV3.go @@ -322,6 +322,60 @@ func (a AtomicQueryV3Inputs) fillSigProofWithZero(s *atomicQueryV3CircuitInputs) s.IssuerAuthState = &merkletree.HashZero } +func (a AtomicQueryV3Inputs) GetPublicStatesInfo() (StatesInfo, error) { + + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + issuerID := a.Claim.IssuerID + var issuerState merkletree.Hash + switch a.ProofType { + case BJJSignatureProofType: + if a.Claim.SignatureProof == nil { + return StatesInfo{}, errors.New(ErrorEmptySignatureProof) + } + if a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State + case Iden3SparseMerkleTreeProofType: + if a.Claim.IncProof == nil { + return StatesInfo{}, errors.New(ErrorEmptyMTPProof) + } + if a.Claim.IncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.IncProof.TreeState.State + default: + return StatesInfo{}, errors.New(ErrorInvalidProofType) + } + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: *issuerID, + State: issuerState, + }, + }, + Gists: []Gist{}, + } + + nonRefProofState := *a.Claim.NonRevProof.TreeState.State + if issuerState != nonRefProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: *issuerID, + State: nonRefProofState, + }) + } + + return statesInfo, nil +} + // AtomicQueryV3PubSignals public inputs type AtomicQueryV3PubSignals struct { BaseConfig @@ -524,3 +578,28 @@ func (ao *AtomicQueryV3PubSignals) PubSignalsUnmarshal(data []byte) error { func (ao AtomicQueryV3PubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQueryV3PubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerState == nil || ao.IssuerClaimNonRevState == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{States: states, Gists: []Gist{}}, nil +} diff --git a/credentialAtomicQueryV3OnChain.go b/credentialAtomicQueryV3OnChain.go index 8ad3bb2..a017e93 100644 --- a/credentialAtomicQueryV3OnChain.go +++ b/credentialAtomicQueryV3OnChain.go @@ -445,6 +445,74 @@ func (a AtomicQueryV3OnChainInputs) fillSigProofWithZero(s *atomicQueryV3OnChain s.IssuerAuthState = &merkletree.HashZero } +func (a AtomicQueryV3OnChainInputs) GetPublicStatesInfo() (StatesInfo, error) { + + if err := a.Validate(); err != nil { + return StatesInfo{}, err + } + + issuerID := a.Claim.IssuerID + var issuerState merkletree.Hash + switch a.ProofType { + case BJJSignatureProofType: + if a.Claim.SignatureProof == nil { + return StatesInfo{}, errors.New(ErrorEmptySignatureProof) + } + if a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.SignatureProof.IssuerAuthIncProof.TreeState.State + case Iden3SparseMerkleTreeProofType: + if a.Claim.IncProof == nil { + return StatesInfo{}, errors.New(ErrorEmptyMTPProof) + } + if a.Claim.IncProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + issuerState = *a.Claim.IncProof.TreeState.State + default: + return StatesInfo{}, errors.New(ErrorInvalidProofType) + } + + if a.Claim.NonRevProof.TreeState.State == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + userID, err := core.ProfileID(*a.ID, a.ProfileNonce) + if err != nil { + return StatesInfo{}, err + } + + if a.GISTProof.Root == nil { + return StatesInfo{}, errors.New(ErrorEmptyGISTProof) + } + + statesInfo := StatesInfo{ + States: []State{ + { + ID: *issuerID, + State: issuerState, + }, + }, + Gists: []Gist{ + { + ID: userID, + Root: *a.GISTProof.Root, + }, + }, + } + + nonRevProofState := *a.Claim.NonRevProof.TreeState.State + if issuerState != nonRevProofState { + statesInfo.States = append(statesInfo.States, State{ + ID: *issuerID, + State: nonRevProofState, + }) + } + + return statesInfo, nil +} + // AtomicQueryV3OnChainPubSignals public inputs type AtomicQueryV3OnChainPubSignals struct { BaseConfig @@ -464,27 +532,6 @@ type AtomicQueryV3OnChainPubSignals struct { IsBJJAuthEnabled int `json:"isBJJAuthEnabled"` } -func (ao *AtomicQueryV3OnChainPubSignals) GetStatesInfo() StatesInfo { - return StatesInfo{ - States: []State{ - { - ID: ao.IssuerID, - State: ao.IssuerState, - }, - { - ID: ao.IssuerID, - State: ao.IssuerClaimNonRevState, - }, - }, - Gists: []Gist{ - { - ID: ao.UserID, - Root: ao.GlobalRoot, - }, - }, - } -} - // PubSignalsUnmarshal unmarshal credentialAtomicQueryV3OnChain.circom public signals func (ao *AtomicQueryV3OnChainPubSignals) PubSignalsUnmarshal(data []byte) error { // expected order: @@ -603,3 +650,32 @@ func (ao *AtomicQueryV3OnChainPubSignals) PubSignalsUnmarshal(data []byte) error func (ao AtomicQueryV3OnChainPubSignals) GetObjMap() map[string]interface{} { return toMap(ao) } + +func (ao AtomicQueryV3OnChainPubSignals) GetStatesInfo() (StatesInfo, error) { + if ao.UserID == nil || ao.IssuerID == nil { + return StatesInfo{}, errors.New(ErrorEmptyID) + } + + if ao.IssuerState == nil || ao.IssuerClaimNonRevState == nil || + ao.GlobalRoot == nil { + return StatesInfo{}, errors.New(ErrorEmptyStateHash) + } + + states := []State{ + { + ID: *ao.IssuerID, + State: *ao.IssuerState, + }, + } + if *ao.IssuerClaimNonRevState != *ao.IssuerState { + states = append(states, State{ + ID: *ao.IssuerID, + State: *ao.IssuerClaimNonRevState, + }) + } + + return StatesInfo{ + States: states, + Gists: []Gist{{ID: *ao.UserID, Root: *ao.GlobalRoot}}, + }, nil +} diff --git a/credentialAtomicQueryV3OnChain_test.go b/credentialAtomicQueryV3OnChain_test.go index 34fb992..1f5ede0 100644 --- a/credentialAtomicQueryV3OnChain_test.go +++ b/credentialAtomicQueryV3OnChain_test.go @@ -12,8 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAttrQueryV3OnChain_SigPart_PrepareInputs(t *testing.T) { - +func createV3OnChaneInputs_Sig(t testing.TB) AtomicQueryV3OnChainInputs { user := it.NewIdentity(t, userPK) issuer := it.NewIdentity(t, issuerPK) @@ -110,6 +109,12 @@ func TestAttrQueryV3OnChain_SigPart_PrepareInputs(t *testing.T) { IsBJJAuthEnabled: 1, } + return in +} + +func TestAttrQueryV3OnChain_SigPart_PrepareInputs(t *testing.T) { + in := createV3OnChaneInputs_Sig(t) + bytesInputs, err := in.InputsMarshal() require.Nil(t, err) @@ -119,6 +124,32 @@ func TestAttrQueryV3OnChain_SigPart_PrepareInputs(t *testing.T) { require.JSONEq(t, exp, string(bytesInputs)) } +func TestAttrQueryV3OnChain_SigPart_GetPublicStatesInfo(t *testing.T) { + in := createV3OnChaneInputs_Sig(t) + + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + + bs, err := json.Marshal(statesInfo) + require.NoError(t, err) + + wantStatesInfo := `{ + "states": [ + { + "id": "27918766665310231445021466320959318414450284884582375163563581940319453185", + "state": "20177832565449474772630743317224985532862797657496372535616634430055981993180" + } + ], + "gists": [ + { + "id": "26109404700696283154998654512117952420503675471097392618762221546565140481", + "root": "11098939821764568131087645431296528907277253709936443029379587475821759259406" + } + ] +}` + require.JSONEq(t, wantStatesInfo, string(bs)) +} + func TestAttrQueryV3OnChain_SigPart_Noop_PrepareInputs(t *testing.T) { user := it.NewIdentity(t, userPK) @@ -385,6 +416,30 @@ func TestAtomicQueryV3OnChainOutputs_Sig_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"), + State: hashFromInt("2943483356559152311923412925436024635269538717812859789851139200242297094"), + }, + { + ID: idFromInt("27918766665310231445021466320959318414450284884582375163563581940319453185"), + State: hashFromInt("20177832565449474772630743317224985532862797657496372535616634430055981993180"), + }, + }, + Gists: []Gist{ + { + ID: idFromInt("26109404700696283154998654512117952420503675471097392618762221546565140481"), + Root: hashFromInt("20177832565449474772630743317224985532862797657496372535616634430055981993180"), + }, + }, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } func TestAtomicQueryV3OnChainOutputs_MTP_CircuitUnmarshal(t *testing.T) { diff --git a/credentialAtomicQueryV3_test.go b/credentialAtomicQueryV3_test.go index ff6adf7..5accd16 100644 --- a/credentialAtomicQueryV3_test.go +++ b/credentialAtomicQueryV3_test.go @@ -10,8 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAttrQueryV3_SigPart_PrepareInputs(t *testing.T) { - +func createV3Inputs_Sig(t testing.TB) AtomicQueryV3Inputs { user := it.NewIdentity(t, userPK) issuer := it.NewIdentity(t, issuerPK) @@ -31,7 +30,7 @@ func TestAttrQueryV3_SigPart_PrepareInputs(t *testing.T) { issuerAuthClaimNonRevMtp, _ := issuer.ClaimRevMTPRaw(t, issuer.AuthClaim) issuerAuthClaimMtp, _ := issuer.ClaimMTPRaw(t, issuer.AuthClaim) - in := AtomicQueryV3Inputs{ + return AtomicQueryV3Inputs{ RequestID: big.NewInt(23), ID: &user.ID, ProfileNonce: profileNonce, @@ -84,7 +83,10 @@ func TestAttrQueryV3_SigPart_PrepareInputs(t *testing.T) { t, "21929109382993718606847853573861987353620810345503358891473103689157378049"), NullifierSessionID: big.NewInt(32), } +} +func TestAttrQueryV3_SigPart_PrepareInputs(t *testing.T) { + in := createV3Inputs_Sig(t) bytesInputs, err := in.InputsMarshal() require.Nil(t, err) @@ -94,6 +96,27 @@ func TestAttrQueryV3_SigPart_PrepareInputs(t *testing.T) { require.JSONEq(t, exp, string(bytesInputs)) } +func TestAttrQueryV3_SigPart_GetPublicStateInfo(t *testing.T) { + in := createV3Inputs_Sig(t) + statesInfo, err := in.GetPublicStatesInfo() + require.NoError(t, err) + + statesInfoJsonBytes, err := json.Marshal(statesInfo) + require.NoError(t, err) + + wantStatesInfoJson := `{ + "states":[ + { + "id":"27918766665310231445021466320959318414450284884582375163563581940319453185", + "state":"20177832565449474772630743317224985532862797657496372535616634430055981993180" + } + ], + "gists":[] +}` + + require.JSONEq(t, wantStatesInfoJson, string(statesInfoJsonBytes)) +} + func TestAttrQueryV3_MTPPart_PrepareInputs(t *testing.T) { user := it.NewIdentity(t, userPK) @@ -356,6 +379,21 @@ func TestAtomicQueryV3Outputs_Sig_CircuitUnmarshal(t *testing.T) { require.NoError(t, err) require.JSONEq(t, string(jsonExp), string(jsonOut)) + + statesInfo, err := exp.GetStatesInfo() + require.NoError(t, err) + wantStatesInfo := StatesInfo{ + States: []State{ + { + ID: idFromInt("21933750065545691586450392143787330185992517860945727248803138245838110721"), + State: hashFromInt("2943483356559152311923412925436024635269538717812859789851139200242297094"), + }, + }, + Gists: []Gist{}, + } + j, err := json.Marshal(statesInfo) + require.NoError(t, err) + require.Equal(t, wantStatesInfo, statesInfo, string(j)) } func TestAtomicQueryV3Outputs_MTP_CircuitUnmarshal(t *testing.T) { diff --git a/errors.go b/errors.go index 5f5ee78..92c04bd 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,5 @@ const ( ErrorEmptySignatureProof = "empty signature proof" ErrorEmptyMTPProof = "empty MTP proof" ErrorInvalidValuesArrSize = "invalid query Values array size" + ErrorEmptyStateHash = "empty state hash" ) diff --git a/stateTransition_test.go b/stateTransition_test.go index b32b420..e6c1a96 100644 --- a/stateTransition_test.go +++ b/stateTransition_test.go @@ -15,8 +15,8 @@ func TestStateTransitionOutput_GetJSONObj(t *testing.T) { id, err := core.IDFromString("1124NoAu14diR5EM1kgUha2uHFkvUrPrTXMtf4tncZ") assert.Nil(t, err) - newState := hashFromInt(big.NewInt(1)) - oldState := hashFromInt(big.NewInt(2)) + newState := hashPtrFromInt(big.NewInt(1)) + oldState := hashPtrFromInt(big.NewInt(2)) sto := StateTransitionPubSignals{ UserID: &id,