Robotics Transformer Inference in TensorFlow: RT-1, RT-2, RT-X, PaLM-E.
Library for Robotics Transformers: RT-1 and RT-X-1.
Installation:
Requirements: Python >= 3.9
Using PyPI
pip install robo-transformers
From Source
Clone this repo:
git clone https://github.com/sebbyjp/robo_transformers.git
cd robo_transformers
Use Poetry:
pip install poetry && poetry config virtualenvs.in-project true
Install dependencies:
poetry install
source .venv/bin/activate
Run RT-1 inference on demo images:
python -m robo_transformers.rt1.rt1_inference
See options:
python -m robo_transformers.rt1.rt1_inference --help
Notes
action, next_policy_state = model.act(time_step, curr_policy_state)
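The call above is presumably made once per control step, with each returned policy state fed into the next call. A minimal sketch of that loop, assuming only the act(time_step, policy_state) -> (action, next_policy_state) contract shown above; StubPolicy and run_episode are hypothetical stand-ins, not part of robo_transformers.

```python
from typing import Any, Callable, Dict, List, Tuple

State = Dict[str, Any]


class StubPolicy:
    """Hypothetical stand-in exposing the same act() contract as the model above."""

    def act(self, time_step: Dict[str, Any], policy_state: State) -> Tuple[Dict[str, Any], State]:
        # A real model would run the network here; the stub only advances a counter.
        next_state = {"t": policy_state["t"] + 1}
        action = {"terminate_episode": 0}
        return action, next_state


def run_episode(
    model: Any,
    get_time_step: Callable[[int], Dict[str, Any]],
    initial_state: State,
    num_steps: int,
) -> Tuple[List[Dict[str, Any]], State]:
    """Call model.act once per step, feeding each returned state into the next call."""
    policy_state = initial_state
    actions: List[Dict[str, Any]] = []
    for step in range(num_steps):
        time_step = get_time_step(step)
        action, policy_state = model.act(time_step, policy_state)
        actions.append(action)
    return actions, policy_state


actions, final_state = run_episode(StubPolicy(), lambda step: {"step": step}, {"t": 0}, num_steps=3)
print(len(actions), final_state)  # 3 {'t': 3}
```

Reusing the initial state on every call instead of next_policy_state would presumably leave the 6-frame window without any history.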
The policy state is the internal state of the network.
In this case, it is a 6-frame window of past observations and actions, plus the index in time:
{'action_tokens': ArraySpec(shape=(6, 11, 1, 1), dtype=dtype('int32'), name='action_tokens'),
'image': ArraySpec(shape=(6, 256, 320, 3), dtype=dtype('uint8'), name='image'),
'step_num': ArraySpec(shape=(1, 1, 1, 1), dtype=dtype('int32'), name='step_num'),
't': ArraySpec(shape=(1, 1, 1, 1), dtype=dtype('int32'), name='t')}
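To make the spec concrete, the sketch below builds NumPy arrays with exactly these shapes and dtypes and shows one way a fixed-length frame window can be advanced. Both helpers (initial_policy_state, push_frame) are hypothetical: the zero-filled starting state and the drop-oldest/append-newest update are assumptions for illustration, not a description of how the network updates its own state.

```python
from typing import Dict

import numpy as np

WINDOW = 6  # number of frames held in the policy state, per the spec above


def initial_policy_state() -> Dict[str, np.ndarray]:
    """Zero-filled arrays matching the shapes and dtypes in the spec above (assumed blank state)."""
    return {
        "action_tokens": np.zeros((WINDOW, 11, 1, 1), dtype=np.int32),
        "image": np.zeros((WINDOW, 256, 320, 3), dtype=np.uint8),
        "step_num": np.zeros((1, 1, 1, 1), dtype=np.int32),
        "t": np.zeros((1, 1, 1, 1), dtype=np.int32),
    }


def push_frame(window: np.ndarray, frame: np.ndarray) -> np.ndarray:
    """Drop the oldest frame and append the newest one (illustrative only)."""
    return np.concatenate([window[1:], frame[np.newaxis]], axis=0)


state = initial_policy_state()
new_image = np.full((256, 320, 3), 255, dtype=np.uint8)
state["image"] = push_frame(state["image"], new_image)
print(state["image"].shape)  # (6, 256, 320, 3)
```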
time_step is the input from the environment:
{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
'observation': {'base_pose_tool_reached': ArraySpec(shape=(7,), dtype=dtype('float32'), name='base_pose_tool_reached'),
'gripper_closed': ArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closed'),
'gripper_closedness_commanded': ArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closedness_commanded'),
'height_to_bottom': ArraySpec(shape=(1,), dtype=dtype('float32'), name='height_to_bottom'),
'image': ArraySpec(shape=(256, 320, 3), dtype=dtype('uint8'), name='image'),
'natural_language_embedding': ArraySpec(shape=(512,), dtype=dtype('float32'), name='natural_language_embedding'),
'natural_language_instruction': ArraySpec(shape=(), dtype=dtype('O'), name='natural_language_instruction'),
'orientation_box': ArraySpec(shape=(2, 3), dtype=dtype('float32'), name='orientation_box'),
'orientation_start': ArraySpec(shape=(4,), dtype=dtype('float32'), name='orientation_in_camera_space'),
'robot_orientation_positions_box': ArraySpec(shape=(3, 3), dtype=dtype('float32'), name='robot_orientation_positions_box'),
'rotation_delta_to_go': ArraySpec(shape=(3,), dtype=dtype('float32'), name='rotation_delta_to_go'),
'src_rotation': ArraySpec(shape=(4,), dtype=dtype('float32'), name='transform_camera_robot'),
'vector_to_go': ArraySpec(shape=(3,), dtype=dtype('float32'), name='vector_to_go'),
'workspace_bounds': ArraySpec(shape=(3, 3), dtype=dtype('float32'), name='workspace_bounds')},
'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')}
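A sketch of assembling a time_step by hand with NumPy, for example to pass a single camera frame and instruction to the model. Everything except the shapes and dtypes is an assumption: plain dicts stand in for whatever container the library expects, the non-image observation fields are zero placeholders, and the step_type values assume the TF-Agents StepType convention (0 = FIRST, 1 = MID, 2 = LAST). The 512-float embedding would come from a sentence encoder this README does not name; zeros are used only to exercise the code. make_time_step and OBS_SPEC are hypothetical names.

```python
from typing import Any, Dict, Tuple

import numpy as np

# Assumed TF-Agents StepType convention: 0 = FIRST, 1 = MID, 2 = LAST.
FIRST, MID, LAST = 0, 1, 2

# (shape, dtype) for every observation key, copied from the spec above.
OBS_SPEC: Dict[str, Tuple[Tuple[int, ...], Any]] = {
    "base_pose_tool_reached": ((7,), np.float32),
    "gripper_closed": ((1,), np.float32),
    "gripper_closedness_commanded": ((1,), np.float32),
    "height_to_bottom": ((1,), np.float32),
    "image": ((256, 320, 3), np.uint8),
    "natural_language_embedding": ((512,), np.float32),
    "natural_language_instruction": ((), object),
    "orientation_box": ((2, 3), np.float32),
    "orientation_start": ((4,), np.float32),
    "robot_orientation_positions_box": ((3, 3), np.float32),
    "rotation_delta_to_go": ((3,), np.float32),
    "src_rotation": ((4,), np.float32),
    "vector_to_go": ((3,), np.float32),
    "workspace_bounds": ((3, 3), np.float32),
}


def make_time_step(image, embedding, instruction: str, step_type: int = FIRST) -> Dict[str, Any]:
    """Build a time_step dict; fields not passed in are zero-filled placeholders."""
    observation = {key: np.zeros(shape, dtype=dtype) for key, (shape, dtype) in OBS_SPEC.items()}
    observation["image"] = np.asarray(image, dtype=np.uint8)
    observation["natural_language_embedding"] = np.asarray(embedding, dtype=np.float32)
    observation["natural_language_instruction"] = np.array(instruction, dtype=object)
    # Reject inputs whose shapes disagree with the spec.
    for key, (shape, _) in OBS_SPEC.items():
        if observation[key].shape != shape:
            raise ValueError(f"{key}: expected shape {shape}, got {observation[key].shape}")
    return {
        "discount": np.float32(1.0),
        "observation": observation,
        "reward": np.float32(0.0),
        "step_type": np.int32(step_type),
    }


ts = make_time_step(np.zeros((256, 320, 3)), np.zeros(512), "pick up the can")
print(ts["observation"]["image"].dtype, ts["step_type"])  # uint8 0
```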
action is the output of the policy:
{'base_displacement_vector': BoundedArraySpec(shape=(2,), dtype=dtype('float32'), name='base_displacement_vector', minimum=-1.0, maximum=1.0),
'base_displacement_vertical_rotation': BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='base_displacement_vertical_rotation', minimum=-3.1415927410125732, maximum=3.1415927410125732),
'gripper_closedness_action': BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closedness_action', minimum=-1.0, maximum=1.0),
'rotation_delta': BoundedArraySpec(shape=(3,), dtype=dtype('float32'), name='rotation_delta', minimum=-1.5707963705062866, maximum=1.5707963705062866),
'terminate_episode': BoundedArraySpec(shape=(3,), dtype=dtype('int32'), name='terminate_episode', minimum=0, maximum=1),
'world_vector': BoundedArraySpec(shape=(3,), dtype=dtype('float32'), name='world_vector', minimum=-1.0, maximum=1.0)}
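The bounds in the action spec can also be enforced on the client side before commands reach a robot. A minimal sketch, assuming actions arrive as a dict of NumPy arrays keyed like the spec; clip_action and ACTION_BOUNDS are hypothetical, and this README does not say whether model outputs can ever leave these ranges.

```python
from typing import Dict

import numpy as np

# (minimum, maximum) for each action key, copied from the action spec above.
# The float32 pi values reproduce the spec's 3.1415927410125732 and 1.5707963705062866.
ACTION_BOUNDS = {
    "base_displacement_vector": (-1.0, 1.0),
    "base_displacement_vertical_rotation": (-np.float32(np.pi), np.float32(np.pi)),
    "gripper_closedness_action": (-1.0, 1.0),
    "rotation_delta": (-np.float32(np.pi / 2), np.float32(np.pi / 2)),
    "terminate_episode": (0, 1),
    "world_vector": (-1.0, 1.0),
}


def clip_action(action: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Clip each action component into its spec bounds, keeping its original dtype."""
    clipped = {}
    for name, value in action.items():
        low, high = ACTION_BOUNDS[name]
        arr = np.asarray(value)
        clipped[name] = np.clip(arr, low, high).astype(arr.dtype)
    return clipped


safe = clip_action({"world_vector": np.array([2.0, -3.0, 0.5], dtype=np.float32)})
print(safe["world_vector"].tolist())  # [1.0, -1.0, 0.5]
```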
TODO:
- Render action, policy_state, observation specs in something prettier like pandas data frame.
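For the TODO above, a stdlib-only fixed-width table is one alternative to a pandas DataFrame. A sketch with the action spec transcribed by hand (bounds rounded to four decimals); format_spec_table is hypothetical and does not read the library's spec objects.

```python
from typing import List, Sequence, Tuple

# (name, shape, dtype, minimum, maximum), transcribed by hand from the action spec above.
Row = Tuple[str, Tuple[int, ...], str, str, str]

ACTION_ROWS: List[Row] = [
    ("base_displacement_vector", (2,), "float32", "-1.0", "1.0"),
    ("base_displacement_vertical_rotation", (1,), "float32", "-3.1416", "3.1416"),
    ("gripper_closedness_action", (1,), "float32", "-1.0", "1.0"),
    ("rotation_delta", (3,), "float32", "-1.5708", "1.5708"),
    ("terminate_episode", (3,), "int32", "0", "1"),
    ("world_vector", (3,), "float32", "-1.0", "1.0"),
]


def format_spec_table(rows: Sequence[Row]) -> str:
    """Render spec rows as a left-aligned, fixed-width text table."""
    header = ("name", "shape", "dtype", "min", "max")
    cells = [header] + [(name, str(shape), dtype, lo, hi) for name, shape, dtype, lo, hi in rows]
    widths = [max(len(row[i]) for row in cells) for i in range(len(header))]
    lines = ["  ".join(cell.ljust(width) for cell, width in zip(row, widths)).rstrip() for row in cells]
    # Dashed rule under the header row.
    lines.insert(1, "  ".join("-" * width for width in widths))
    return "\n".join(lines)


print(format_spec_table(ACTION_ROWS))
```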