Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow fixed parameters in GeneralEncoder #252

Merged
merged 3 commits into from
Nov 18, 2024

Conversation

nikhilkhatri
Copy link

This PR adds the ability to fix parameters of some parameterised gates in a GeneralEncoder.

This allows, for example, a single parameterised gate to have a fixed parameter, shared across all items in the batch, while all other parameters can be dynamically provided. The change is fully backwards compatible, and no changes need to be made to existing encoder definitions.

This supports ansatze which have some fixed gates which aren't Paulis or other named operators.
I expect this could already be achieved by subclassing existing operators, but I think this imposes less effort on the user.

Opening this to gauge interest. Shall update documentation and tests pending feedback.

>>> import torchquantum as tq
>>> import torch
>>> 
>>> 
>>> enc1 = [{'input_idx': [0], 'func': 'rx', 'wires': [0]},
...        {'input_idx': [1], 'func': 'rx', 'wires': [1]},
...        {'input_idx': [2], 'func': 'rx', 'wires': [2]},
...        {'input_idx': [3], 'func': 'rx', 'wires': [3]}]
>>> 
>>> # Same as above, but now with 0.3 provided to the first RX gate always
>>> enc2 = [{'params': [0.3], 'func': 'rx', 'wires': [0]},
...        {'input_idx': [0], 'func': 'rx', 'wires': [1]},
...        {'input_idx': [1], 'func': 'rx', 'wires': [2]},
...        {'input_idx': [2], 'func': 'rx', 'wires': [3]}]
>>> 
>>> 
>>> e1 = tq.GeneralEncoder(enc1)
>>> e2 = tq.GeneralEncoder(enc2)
>>> 
>>> qdev = tq.QuantumDevice(4)
>>> 
>>> 
>>> params1 = torch.tensor([[0.3, 0.4, 0.5, 0.6], [0.3, 0, 0, 0]])
>>> params2 = torch.tensor([[0.4, 0.5, 0.6], [0, 0, 0]])
>>> 
>>> qdev.reset_states(2)
>>> e1(qdev, params1)
>>> print(qdev.get_states_1d())
tensor([[ 0.8970+0.0000j,  0.0000-0.2775j,  0.0000-0.2290j, -0.0709+0.0000j,
          0.0000-0.1818j, -0.0562+0.0000j, -0.0464+0.0000j,  0.0000+0.0144j,
          0.0000-0.1356j, -0.0419+0.0000j, -0.0346+0.0000j,  0.0000+0.0107j,
         -0.0275+0.0000j,  0.0000+0.0085j,  0.0000+0.0070j,  0.0022+0.0000j],
        [ 0.9888+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000-0.1494j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j]])
>>> 
>>> qdev.reset_states(2)
>>> e2(qdev, params2)
>>> print(qdev.get_states_1d())
tensor([[ 0.8970+0.0000j,  0.0000-0.2775j,  0.0000-0.2290j, -0.0709+0.0000j,
          0.0000-0.1818j, -0.0562+0.0000j, -0.0464+0.0000j,  0.0000+0.0144j,
          0.0000-0.1356j, -0.0419+0.0000j, -0.0346+0.0000j,  0.0000+0.0107j,
         -0.0275+0.0000j,  0.0000+0.0085j,  0.0000+0.0070j,  0.0022+0.0000j],
        [ 0.9888+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000-0.1494j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,
          0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j,  0.0000+0.0000j]])

@01110011011101010110010001101111 01110011011101010110010001101111 changed the base branch from main to dev March 24, 2024 18:01
@nikhilkhatri
Copy link
Author

Hi @GenericP3rson, could I get a first-pass review on this please?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, on an initial review, this looks solid to me! I want to run it through the rest of the team before officially merging; we’ll meet sometime this week so you’ll get a more definite response by then!

Also, as you mentioned, if you could add some tests, that would be great!

@nikhilkhatri
Copy link
Author

I've added the tests now @GenericP3rson !

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, thanks so much for the updates! Will merge in!

@01110011011101010110010001101111 01110011011101010110010001101111 merged commit f6a074b into mit-han-lab:dev Nov 18, 2024
4 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants