
RoCoDA Alpha 0.1

Counterfactual Data Augmentation for Data-Efficient Robot Learning from Demonstrations

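RoCoDA augments robot demonstration datasets with counterfactual demonstrations, so that policies can be learned from fewer collected demonstrations.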

To install via uv (recommended):

uv venv
source .venv/bin/activate
uv pip install rocoda

Examples

Instantiate a DataPrep Object:

# NOTE: all data preparation is done in place, modifying the HDF5 file on disk
# (keep a copy of the original file if you need it)
from rocoda.data.prep import DataPrep
from task import StackThree_D0  # importing the task class registers it

data = DataPrep(
    hdf5_path="stack_three.hdf5",
    env_type="robosuite",
)

Render a Frame:

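from IPython.display import display  # Jupyter/IPython only; use img.show() elsewhere
from PIL import Image

# Arguments are presumably the demo index (0) and the timestep (10)
img = data.render_demo_state(0, 10, camera_name="agentview")
img = Image.fromarray(img)
# robosuite renders frames upside down, so flip the image vertically
img = img.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
display(img)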
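The vertical flip is needed because robosuite's offscreen renderer returns frames upside down (OpenGL stores rows bottom-up). A minimal sketch, assuming `render_demo_state` returns an H×W×3 `uint8` NumPy array (which the call to `Image.fromarray` suggests): the same flip can be applied to the array itself with `np.flipud` before converting it. `flip_frame` is a hypothetical helper, not part of RoCoDA.

```python
import numpy as np


def flip_frame(frame: np.ndarray) -> np.ndarray:
    """Return the frame with its rows reversed (top <-> bottom).

    Equivalent to PIL's Image.Transpose.FLIP_TOP_BOTTOM, but applied to the
    raw array before it is converted into a PIL image.
    """
    if frame.ndim != 3:
        raise ValueError("expected an H x W x C image array")
    return np.flipud(frame)


# Toy 2x1 RGB "image": a red pixel row on top of a blue pixel row.
frame = np.array([[[255, 0, 0]], [[0, 0, 255]]], dtype=np.uint8)
flipped = flip_frame(frame)
print(flipped[0, 0].tolist())  # [0, 0, 255]: the blue row is now on top
```

With a real rendered frame, `Image.fromarray(flip_frame(frame)).save("frame.png")` writes the image to disk, which is convenient outside a notebook where `display` is unavailable.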

Define Subtask Boundaries via a Heuristic:

from typing import Mapping

from rocoda.environments.robosuite import RobosuiteEnvironment


# Subtask termination heuristic for StackThree: one boolean per subtask
def heuristic(rocoda_env: RobosuiteEnvironment) -> Mapping[str, bool]:
    env = rocoda_env.get_underlying_env()  # the wrapped robosuite environment
    gripper = env.robots[0].gripper

    return {
        # robot0 is holding cubeA
        "grasp_1": env._check_grasp(gripper=gripper, object_geoms=env.cubeA),
        # cubeA is stacked
        "stack_1": env._check_cubeA_stacked(),
        # robot0 is holding cubeC
        "grasp_2": env._check_grasp(gripper=gripper, object_geoms=env.cubeC),
        # final subtask: never set to True in this example
        # (presumably the last subtask simply ends with the demo)
        "stack_2": False,
    }


# Apply the subtask termination heuristic; subtask names are listed in order
data.apply_subtask_term_heuristic(["grasp_1", "stack_1", "grasp_2", "stack_2"], heuristic)
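The heuristic reports one boolean per subtask at each state. This README does not say exactly how RoCoDA turns these signals into boundaries; one plausible convention, sketched below under that assumption, is to end each subtask at the first step (after the previous boundary) where its signal is true, and to let the final subtask run to the end of the demo. `subtask_end_indices` is a hypothetical helper operating on a list of per-step signal dicts, not RoCoDA's implementation.

```python
from typing import List, Mapping, Sequence


def subtask_end_indices(
    subtask_names: Sequence[str],
    signals_per_step: Sequence[Mapping[str, bool]],
) -> List[int]:
    """Return the exclusive end index of each subtask in one demo.

    Assumed convention: subtask i ends right after the first step, at or after
    the previous subtask's end, where its signal is True. The final subtask
    always ends at the last step of the demo.
    """
    num_steps = len(signals_per_step)
    ends: List[int] = []
    start = 0
    for i, name in enumerate(subtask_names):
        if i == len(subtask_names) - 1:
            ends.append(num_steps)  # last subtask runs to the end of the demo
            break
        end = next(
            (t + 1 for t in range(start, num_steps) if signals_per_step[t][name]),
            None,
        )
        if end is None:
            raise ValueError(f"signal {name!r} never becomes True after step {start}")
        ends.append(end)
        start = end
    return ends


names = ["grasp_1", "stack_1", "grasp_2", "stack_2"]


def step(**true_signals: bool) -> Mapping[str, bool]:
    """Build one step's signal dict; unspecified signals are False."""
    return {n: true_signals.get(n, False) for n in names}


demo = [step(), step(grasp_1=True), step(grasp_1=True),
        step(stack_1=True), step(grasp_2=True), step()]
print(subtask_end_indices(names, demo))  # [2, 4, 5, 6]
```

Searching only after the previous boundary keeps the segments in order even if an earlier signal flickers back to true later in the demo.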

Define the Causal Groups:

# Maps each subtask name to a list of causal groups; each group is a tuple of
# entity names (robots and objects) whose states are linked during that subtask

causality = {
    "grasp_1": [("robot0", "cubeA")],
    "stack_1": [("robot0", "cubeA", "cubeB")],
    "grasp_2": [("robot0", "cubeC"), ("cubeA", "cubeB")],
    "stack_2": [("robot0", "cubeA", "cubeB", "cubeC")],
}

# Set the causality groups for the dataset
data.set_causal_groups(causality)
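In this example, `grasp_2` splits the scene into `("robot0", "cubeC")` and the already-stacked `("cubeA", "cubeB")`, which suggests that entities in different groups are treated as causally independent during that subtask. A quick sanity check before calling `set_causal_groups` can catch typos. The sketch below is a hypothetical helper, not part of RoCoDA: it assumes groups within one subtask should not share an entity, and reports missing subtasks, unknown entity names, and overlapping groups.

```python
from typing import Dict, Iterable, List, Mapping, Sequence, Tuple

CausalGroups = Mapping[str, Sequence[Tuple[str, ...]]]


def check_causal_groups(
    causality: CausalGroups,
    subtask_names: Iterable[str],
    entities: Iterable[str],
) -> List[str]:
    """Return a list of problems found in a causality mapping (empty if none)."""
    known = set(entities)
    problems: List[str] = []
    for name in subtask_names:
        if name not in causality:
            problems.append(f"subtask {name!r} has no causal groups")
    for subtask, groups in causality.items():
        seen: Dict[str, int] = {}  # entity -> index of the group it appeared in
        for i, group in enumerate(groups):
            for entity in group:
                if entity not in known:
                    problems.append(f"{subtask}: unknown entity {entity!r}")
                if entity in seen:
                    problems.append(
                        f"{subtask}: {entity!r} appears in groups {seen[entity]} and {i}"
                    )
                seen[entity] = i
    return problems


subtasks = ["grasp_1", "stack_1", "grasp_2", "stack_2"]
entities = ["robot0", "cubeA", "cubeB", "cubeC"]
causality = {
    "grasp_1": [("robot0", "cubeA")],
    "stack_1": [("robot0", "cubeA", "cubeB")],
    "grasp_2": [("robot0", "cubeC"), ("cubeA", "cubeB")],
    "stack_2": [("robot0", "cubeA", "cubeB", "cubeC")],
}
print(check_causal_groups(causality, subtasks, entities))  # []

bad = {"grasp_1": [("robot0", "cubeA"), ("cubeA", "cubeB")]}
print(check_causal_groups(bad, ["grasp_1"], entities))
# ["grasp_1: 'cubeA' appears in groups 0 and 1"]
```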

Run Causal Augmentation:

from rocoda.augments import causal

# Generate 10 new (counterfactual) episodes from the prepared dataset
causal.augment(
    filepath="stack_three.hdf5",
    num_new_episodes=10,
    across_demos=True,
)
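As a toy illustration of why causal groups matter (a sketch of the general idea, not RoCoDA's implementation), the code below applies one planar rigid transform, a rotation about the world z-axis followed by a translation, to every entity in a single causal group. Because all members move together, their relative geometry (here, the gripper's distance to the cube it is grasping) is preserved, while entities outside the group are left untouched. Poses are simplified to hypothetical `(x, y, yaw)` tuples, and the helper names are illustrative only.

```python
import math
from typing import Dict, Tuple

Pose2D = Tuple[float, float, float]  # hypothetical planar pose: (x, y, yaw in radians)


def transform_pose(pose: Pose2D, dtheta: float, dx: float, dy: float) -> Pose2D:
    """Rotate a pose by dtheta about the world origin, then translate by (dx, dy)."""
    x, y, yaw = pose
    c, s = math.cos(dtheta), math.sin(dtheta)
    return (c * x - s * y + dx, s * x + c * y + dy, yaw + dtheta)


def augment_group(
    poses: Dict[str, Pose2D],
    group: Tuple[str, ...],
    dtheta: float,
    dx: float,
    dy: float,
) -> Dict[str, Pose2D]:
    """Apply the same rigid transform to every entity in `group`; leave others unchanged."""
    return {
        name: transform_pose(p, dtheta, dx, dy) if name in group else p
        for name, p in poses.items()
    }


def planar_distance(a: Pose2D, b: Pose2D) -> float:
    """Euclidean distance between two poses in the plane."""
    return math.hypot(a[0] - b[0], a[1] - b[1])


poses = {
    "robot0": (0.10, 0.00, 0.0),  # end-effector pose while grasping cubeC
    "cubeC": (0.12, 0.00, 0.0),
    "cubeA": (0.30, 0.20, 0.0),
    "cubeB": (0.30, 0.20, 0.0),
}
# Counterfactual for the grasp_2 subtask: move only the ("robot0", "cubeC") group
new = augment_group(poses, ("robot0", "cubeC"), math.pi / 2, 0.05, -0.05)
print(round(planar_distance(poses["robot0"], poses["cubeC"]), 6))  # 0.02
print(round(planar_distance(new["robot0"], new["cubeC"]), 6))      # 0.02
```

A real augmentation would also need the transformed end-effector trajectory to remain reachable by the robot, which this toy example ignores.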


Download files


Source Distribution

rocoda-0.1.1.tar.gz (11.8 kB)

Uploaded Source

Built Distribution


rocoda-0.1.1-py3-none-any.whl (13.0 kB)

Uploaded Python 3

File details

Details for the file rocoda-0.1.1.tar.gz.

File metadata

  • Download URL: rocoda-0.1.1.tar.gz
  • Upload date:
  • Size: 11.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.7.5

File hashes

Hashes for rocoda-0.1.1.tar.gz
Algorithm Hash digest
SHA256 7ae9f5b2084767c2358029fec0401b60adb4540657f85a63ae0c3be77871165d
MD5 402718f1279ff5d9a7cd02debf8be1e0
BLAKE2b-256 153a126ec7f9ff61c96a1c146f4360ca944bcb28db5b2e8fac6f36a8158f9f87


File details

Details for the file rocoda-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: rocoda-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 13.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.7.5

File hashes

Hashes for rocoda-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5f9618d99cf2b0a55250e48bae17ad87120b6e611fa213b33b05af3e13457fd9
MD5 f28b0382ac0e0c44599c8c5622b8a12a
BLAKE2b-256 f9f27f01f650b10742ac83efedd2a9544d4a226d3f1fe5b5d593aba0f525d301

