Skip to content

Commit

Permalink
Merge pull request #79 from stratosphereips/cyst-integration
Browse files Browse the repository at this point in the history
Cyst integration
  • Loading branch information
ondrej-lukas authored Jan 27, 2025
2 parents d15cf5c + 77e5edd commit 3693e0e
Show file tree
Hide file tree
Showing 22 changed files with 230 additions and 425 deletions.
66 changes: 41 additions & 25 deletions agents/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from os import path
sys.path.append(path.dirname(path.dirname(path.dirname( path.abspath(__file__) ))))
#with the path fixed, we can import now
from env.game_components import Action, ActionType, GameState, Observation, IP, Network
from AIDojoCoordinator.game_components import Action, ActionType, GameState, Observation, IP, Network
import ipaddress

def generate_valid_actions_concepts(state: GameState)->list:
Expand All @@ -22,12 +22,12 @@ def generate_valid_actions_concepts(state: GameState)->list:
# TODO ADD neighbouring networks
# Only scan local networks from local hosts
if network.is_private() and source_host.is_private():
valid_actions.add(Action(ActionType.ScanNetwork, params={"target_network": network, "source_host": source_host,}))
valid_actions.add(Action(ActionType.ScanNetwork, parameters={"target_network": network, "source_host": source_host,}))
# Service Scans
for host in state.known_hosts:
# Do not try to scan a service from hosts outside local networks towards local networks
if host.is_private() and source_host.is_private():
valid_actions.add(Action(ActionType.FindServices, params={"target_host": host, "source_host": source_host,}))
valid_actions.add(Action(ActionType.FindServices, parameters={"target_host": host, "source_host": source_host,}))
# Service Exploits
for host, service_list in state.known_services.items():
# Only exploit local services from local hosts
Expand All @@ -37,52 +37,65 @@ def generate_valid_actions_concepts(state: GameState)->list:
for service in service_list:
# Do not consider local services, which are internal to the host
if not service.is_local:
valid_actions.add(Action(ActionType.ExploitService, params={"target_host": host,"target_service": service,"source_host": source_host,}))
valid_actions.add(Action(ActionType.ExploitService, parameters={"target_host": host,"target_service": service,"source_host": source_host,}))
# Find Data Scans
for host in state.controlled_hosts:
valid_actions.add(Action(ActionType.FindData, params={"target_host": host, "source_host": host}))
valid_actions.add(Action(ActionType.FindData, parameters={"target_host": host, "source_host": host}))

# Data Exfiltration
for source_host, data_list in state.known_data.items():
for data in data_list:
for target_host in state.controlled_hosts:
if target_host != source_host:
valid_actions.add(Action(ActionType.ExfiltrateData, params={"target_host": target_host, "source_host": source_host, "data": data}))
valid_actions.add(Action(ActionType.ExfiltrateData, parameters={"target_host": target_host, "source_host": source_host, "data": data}))
return list(valid_actions)

def generate_valid_actions(state: GameState)->list:
def generate_valid_actions(state: GameState, include_blocks=False)->list:
"""Function that generates a list of all valid actions in a given state"""
valid_actions = set()
def is_fw_blocked(state, src_ip, dst_ip)->bool:
blocked = False
try:
blocked = dst_ip in state.known_blocks[src_ip]
except KeyError:
pass #this src ip has no known blocks
return blocked

for src_host in state.controlled_hosts:
#Network Scans
for network in state.known_networks:
# TODO ADD neighbouring networks
valid_actions.add(Action(ActionType.ScanNetwork, params={"target_network": network, "source_host": src_host,}))
valid_actions.add(Action(ActionType.ScanNetwork, parameters={"target_network": network, "source_host": src_host,}))
# Service Scans
for host in state.known_hosts:
valid_actions.add(Action(ActionType.FindServices, params={"target_host": host, "source_host": src_host,}))
if not is_fw_blocked(state, src_host,host):
valid_actions.add(Action(ActionType.FindServices, parameters={"target_host": host, "source_host": src_host,}))
# Service Exploits
for host, service_list in state.known_services.items():
for service in service_list:
valid_actions.add(Action(ActionType.ExploitService, params={"target_host": host,"target_service": service,"source_host": src_host,}))
if not is_fw_blocked(state, src_host,host):
for service in service_list:
valid_actions.add(Action(ActionType.ExploitService, parameters={"target_host": host,"target_service": service,"source_host": src_host,}))
# Data Scans
for host in state.controlled_hosts:
valid_actions.add(Action(ActionType.FindData, params={"target_host": host, "source_host": host}))
if not is_fw_blocked(state, src_host,host):
valid_actions.add(Action(ActionType.FindData, parameters={"target_host": host, "source_host": host}))

# Data Exfiltration
for src_host, data_list in state.known_data.items():
for data in data_list:
for trg_host in state.controlled_hosts:
if trg_host != src_host:
valid_actions.add(Action(ActionType.ExfiltrateData, params={"target_host": trg_host, "source_host": src_host, "data": data}))
if not is_fw_blocked(state, src_host,trg_host):
valid_actions.add(Action(ActionType.ExfiltrateData, parameters={"target_host": trg_host, "source_host": src_host, "data": data}))

# BlockIP
for src_host in state.controlled_hosts:
for target_host in state.controlled_hosts:
for blocked_ip in state.known_hosts:
valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":src_host, "blocked_host":blocked_ip}))


if include_blocks:
# BlockIP
if include_blocks:
for src_host in state.controlled_hosts:
for target_host in state.controlled_hosts:
if not is_fw_blocked(state, src_host,target_host):
for blocked_ip in state.known_hosts:
valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":src_host, "blocked_host":blocked_ip}))
return list(valid_actions)

def state_as_ordered_string(state:GameState)->str:
Expand All @@ -97,6 +110,9 @@ def state_as_ordered_string(state:GameState)->str:
ret += "},data:{"
for host in sorted(state.known_data.keys()):
ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_data[host])])}]"
ret += "},blocks:{"
for host in sorted(state.known_blocks.keys()):
ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_blocks[host])])}]"
ret += "}"
return ret

Expand Down Expand Up @@ -480,7 +496,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
if src_host_concept == host_concept:
new_src_host = concept_mapping['controlled_hosts'][host_concept]

action = Action(ActionType.ExploitService, params={"target_host": new_target_host, "target_service": new_target_service, "source_host": new_src_host})
action = Action(ActionType.ExploitService, parameters={"target_host": new_target_host, "target_service": new_target_service, "source_host": new_src_host})

elif action._type == ActionType.ExfiltrateData:
# parameters = {
Expand Down Expand Up @@ -508,7 +524,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
data_concept = action.parameters['data']
new_data = data_concept

action = Action(ActionType.ExfiltrateData, params={"target_host": new_target_host, "source_host": new_src_host, "data": new_data})
action = Action(ActionType.ExfiltrateData, parameters={"target_host": new_target_host, "source_host": new_src_host, "data": new_data})

elif action._type == ActionType.FindData:
# parameters = {
Expand All @@ -529,7 +545,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
if src_host_concept == host_concept:
new_src_host = concept_mapping['controlled_hosts'][host_concept]

action = Action(ActionType.FindData, params={"target_host": new_target_host, "source_host": new_src_host})
action = Action(ActionType.FindData, parameters={"target_host": new_target_host, "source_host": new_src_host})

elif action._type == ActionType.ScanNetwork:
target_net_concept = action.parameters['target_network']
Expand All @@ -545,7 +561,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
for host_concept in concept_mapping['controlled_hosts']:
if src_host_concept == host_concept:
new_src_host = concept_mapping['controlled_hosts'][host_concept]
action = Action(ActionType.ScanNetwork, params={"source_host": new_src_host, "target_network": new_target_network} )
action = Action(ActionType.ScanNetwork, parameters={"source_host": new_src_host, "target_network": new_target_network} )

elif action._type == ActionType.FindServices:
# parameters = {
Expand All @@ -565,6 +581,6 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
for host_concept in concept_mapping['controlled_hosts']:
if src_host_concept == host_concept:
new_src_host = concept_mapping['controlled_hosts'][host_concept]
action = Action(ActionType.FindServices, params={"source_host": new_src_host, "target_host": new_target_host} )
action = Action(ActionType.FindServices, parameters={"source_host": new_src_host, "target_host": new_target_host} )

return action
19 changes: 8 additions & 11 deletions agents/attackers/concepts_q_learning/conceptual_q_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@
# Arti
# Sebastian Garcia. [email protected]
import sys
from os import path, makedirs
import numpy as np
import random
import pickle
import argparse
import logging
# This is used so the agent can see the environment and game component
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ) ) ))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ))))
import mlflow
import subprocess
from os import path, makedirs
from AIDojoCoordinator.game_components import Action, Observation, GameState, AgentStatus

# This is used so the agent can see the environment and game component
# with the path fixed, we can import now
from env.game_components import Action, Observation, GameState
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ))))
from base_agent import BaseAgent
from agent_utils import generate_valid_actions, state_as_ordered_string, convert_concepts_to_actions, convert_ips_to_concepts
import mlflow
import subprocess


class QAgent(BaseAgent):

Expand Down Expand Up @@ -383,15 +380,15 @@ def play_game(self, observation_ip, episode_num, testing=False):
test_end = test_observation.end
test_info = test_observation.info

if test_info and test_info['end_reason'] == 'blocked':
if test_info and test_info['end_reason'] == AgentStatus.Fail:
test_detected +=1
test_num_detected_steps += [num_steps]
test_num_detected_returns += [reward]
elif test_info and test_info['end_reason'] == 'goal_reached':
elif test_info and test_info['end_reason'] == AgentStatus.Success:
test_wins += 1
test_num_win_steps += [num_steps]
test_num_win_returns += [reward]
elif test_info and test_info['end_reason'] == 'max_steps':
elif test_info and test_info['end_reason'] == AgentStatus.TimeoutReached:
test_max_steps += 1
test_num_max_steps_steps += [num_steps]
test_num_max_steps_returns += [reward]
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/double_q_learning/double_q_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import logging
from torch.utils.tensorboard import SummaryWriter
import time
from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import Action, Observation, GameState
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import Action, Observation, GameState

class DoubleQAgent:

Expand Down
2 changes: 1 addition & 1 deletion agents/attackers/gnn_reinforce/gnn_REINFORCE_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))

#with the path fixed, we can import now
from env.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
Expand Down
9 changes: 2 additions & 7 deletions agents/attackers/interactive_tui/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,8 @@
import jinja2
from tenacity import retry, stop_after_attempt

sys.path.append(
path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
)
from env.game_components import (
ActionType,
Observation,
)

from AIDojoCoordinator.game_components import ActionType, Observation

sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))

Expand Down
31 changes: 11 additions & 20 deletions agents/attackers/interactive_tui/interactive_tui.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
#
# Author: Maria Rigaki - [email protected]
#
from textual.app import App, ComposeResult, Widget
from textual.widgets import Tree, Button, RichLog, Select, Input
from textual.containers import Vertical, VerticalScroll, Horizontal
from textual.validation import Function
from textual import on
from textual.reactive import reactive

import sys
from os import path
import os
import logging
import ipaddress
import argparse
import asyncio

from textual.app import App, ComposeResult, Widget
from textual.widgets import Tree, Button, RichLog, Select, Input
from textual.containers import Vertical, VerticalScroll, Horizontal
from textual.validation import Function
from textual import on
from textual.reactive import reactive
from assistant import LLMAssistant

# This is used so the agent can see the environment and game components
sys.path.append(
path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
)
from env.game_components import Network, IP
from env.game_components import ActionType, Action, GameState, Observation
from AIDojoCoordinator.game_components import Network, IP, ActionType, Action, GameState, Observation

# This is used so the agent can see the BaseAgent
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
Expand Down Expand Up @@ -558,7 +549,7 @@ def generate_action(self, state: GameState) -> Action:
self.network_input[:-3], mask=int(self.network_input[-2:])
),
}
action = Action(action_type=self.next_action, params=parameters)
action = Action(action_type=self.next_action, parameters=parameters)
else:
self.notify("Please provide valid inputs", severity="error")
elif self.next_action in [ActionType.FindServices, ActionType.FindData]:
Expand All @@ -567,7 +558,7 @@ def generate_action(self, state: GameState) -> Action:
"source_host": IP(self.src_host_input),
"target_host": IP(self.target_host_input),
}
action = Action(action_type=self.next_action, params=parameters)
action = Action(action_type=self.next_action, parameters=parameters)
else:
self.notify("Please provide valid inputs", severity="error")
elif self.next_action == ActionType.ExploitService:
Expand All @@ -582,7 +573,7 @@ def generate_action(self, state: GameState) -> Action:
"target_service": service,
}
action = Action(
action_type=self.next_action, params=parameters
action_type=self.next_action, parameters=parameters
)
break
else:
Expand All @@ -600,7 +591,7 @@ def generate_action(self, state: GameState) -> Action:
"data": datum,
}
action = Action(
action_type=self.next_action, params=parameters
action_type=self.next_action, parameters=parameters
)
else:
parameters = self.data_input
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/llm/llm_agent-2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))


from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import ActionType, Action, IP, Data, Network, Service
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service

import openai
from tenacity import retry, stop_after_attempt
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/llm/llm_agent-3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))


from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import ActionType, Action, IP, Data, Network, Service
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service

import openai
from tenacity import retry, stop_after_attempt
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/llm/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))


from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import ActionType, Action, IP, Data, Network, Service
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service

import openai
from tenacity import retry, stop_after_attempt
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/llm_embed/llm_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from os import path
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))

from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import Action, ActionType
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import Action, ActionType

import numpy as np
import torch
Expand Down
4 changes: 2 additions & 2 deletions agents/attackers/llm_embed_dqn/llm_embed_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from os import path
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))

from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.game_components import Action, ActionType
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
from AIDojoCoordinator.game_components import Action, ActionType
from sentence_transformers import SentenceTransformer

import numpy as np
Expand Down
Loading

0 comments on commit 3693e0e

Please sign in to comment.