Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Protocol refactor #46

Open
wants to merge 7 commits into
base: cart-pole-3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,5 @@ venv
*$py.class

# Protobuf auto-generated bindings
firmware/src/protocol.pb.[ch]
firmware/src/nanopb.pb.[ch]
cartpole/device/protocol_pb2.py
cartpole/device/nanopb_pb2.py
firmware/src/proto
cartpole/device/proto
60 changes: 42 additions & 18 deletions cartpole/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,46 @@

def parse_args():
common = argparse.ArgumentParser(
prog='cartpole',
description='cartpole control experiments'
prog='cartpole', description='cartpole control experiments'
)

subparsers = common.add_subparsers(title='commands', dest='command', required=True, help='command help')
subparsers = common.add_subparsers(
title='commands', dest='command', required=True, help='command help'
)

# common arguments

common.add_argument('-S', '--simulation', action='store_true', help='simulation mode')
common.add_argument(
'-S', '--simulation', action='store_true', help='simulation mode'
)
common.add_argument('-c', '--config', type=str, help='cartpole yaml config file')
common.add_argument('-m', '--mcap', type=str, default='', help='mcap log file')
common.add_argument('-a', '--advance', type=float, default=0.01, help='advance simulation time (seconds)')
common.add_argument(
'-a',
'--advance',
type=float,
default=0.01,
help='advance simulation time (seconds)',
)

# eval arguments
eval = subparsers.add_parser('eval', help='system identification')

eval.add_argument('-d', '--duration', type=float, default=10.0, help='experiment duration (seconds)')
eval.add_argument(
'-d',
'--duration',
type=float,
default=10.0,
help='experiment duration (seconds)',
)
eval.add_argument('-O', '--output', type=str, help='output yaml config file')

return common.parse_args()


def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) -> None:
log.info('parameters evaluation')
random.seed(0)

position_margin = 0.01
position_tolerance = 0.005
Expand All @@ -48,8 +64,10 @@ def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) ->

log.info(f'run calibration session for {duration:.2f} seconds')
device.reset()
time.sleep(5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Это чтоб палочка успокоилась?


start = device.get_state()
print(start)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Это потом подчистите?

state = start

target = Target(position=0, velocity=0, acceleration=0)
Expand All @@ -59,25 +77,29 @@ def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) ->
state = device.get_state()

if abs(target.position - state.cart_position) < position_tolerance:
position = random.uniform(position_max/2, position_max)
position = random.uniform(position_max / 2, position_max)
target.position = position if target.position < 0 else -position
target.velocity = random.uniform(velocity_max/2, velocity_max)
target.acceleration = random.uniform(acceleration_max/2, acceleration_max)
target.velocity = random.uniform(velocity_max / 2, velocity_max)
target.acceleration = random.uniform(acceleration_max / 2, acceleration_max)

log.info(f'target {target}')
device.set_target(target)
state = device.set_target(target)

log.publish('/cartpole/state', state, state.stamp)
log.publish('/cartpole/target', target, state.stamp)
log.publish('/cartpole/info', device.get_info(), state.stamp)
Comment on lines -70 to -72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

А почему решили убрать stamp?

log.publish('/cartpole/state', state)
log.publish('/cartpole/target', target)
# log.publish('/cartpole/info', device.get_info(), state.stamp)

states.append(state)
device.advance(advance)

if args.simulation:
time.sleep(advance) # simulate real time
time.sleep(advance) # simulate real time

if state.error:
print('ERR', state.error)

log.info(f'find parameters')
print(len(states))
parameters = find_parameters(states, config.parameters.gravity)

log.info(f'parameters: {parameters}')
Expand Down Expand Up @@ -106,24 +128,26 @@ def main():

if args.simulation:
log.info('simulation mode')
device = Simulator(integration_step=min(args.advance/20, 0.001))
device = Simulator(integration_step=min(args.advance / 20, 0.001))
else:
raise NotImplementedError()

from cartpole.device import CartPoleDevice
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

МБ сверху лучше?


device = CartPoleDevice(hard_reset=True)

if args.config:
log.debug(f'config file: {args.config}')
config = Config.from_yaml_file(args.config)
else:
log.warning('no config file specified, using defaults')
config = Config()


device.set_config(config)

if args.command == 'eval':
evaluate(device, config, args)
else:
raise NotImplementedError()


if __name__ == "__main__":
main()
29 changes: 7 additions & 22 deletions cartpole/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import torch

from pydantic import BaseModel, Field
from typing import Any
from typing import Any, Optional

import json
import yaml
import time


class Limits(BaseModel):
Expand Down Expand Up @@ -152,7 +153,7 @@ class State(BaseModel):
pole_angle - absolute accumulated pole angle (rad)
pole_angular_velocity - pole angular velocity (rad/s)

stamp - system time stamp (s)
stamp - system time stamp in seconds
error - system error code
'''

Expand All @@ -163,25 +164,9 @@ class State(BaseModel):
pole_angle: float = 0.0
pole_angular_velocity: float = 0.0

stamp: float = 0.0
stamp: float = Field(default_factory=time.perf_counter)
error: Error = Error.NO_ERROR

class Config:
@staticmethod
def json_schema_extra(schema: Any, model: Any) -> None:
# make schema lightweight
schema.pop('definitions', None)

properties = schema['properties']
for name in properties:
properties[name].pop('title', None)

# simplify schema for foxglove
properties['error'] = {
'type': 'integer',
'enum': [e.value for e in Error]
}

def validate(self, config: Config) -> None:
'''
Validates state against limits.
Expand Down Expand Up @@ -240,9 +225,9 @@ class Target(BaseModel):
If velocity/accleration is not specified (absolute value needed), use control limit as a default.
'''

position: float | None = None
velocity: float | None = None
acceleration: float | None = None
position: Optional[float] = None
velocity: Optional[float] = None
acceleration: Optional[float] = None

def acceleration_or(self, default: float) -> float:
'''
Expand Down
2 changes: 1 addition & 1 deletion cartpole/device/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from cartpole.device._device import *
from cartpole.device.device import CartPoleDevice
79 changes: 0 additions & 79 deletions cartpole/device/_device.py

This file was deleted.

Loading