Skip to content

Commit

Permalink
fix: changed github -> pip for git+ installs
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrickAlphaC committed Sep 22, 2024
1 parent b04da78 commit eb3a75a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 58 deletions.
47 changes: 21 additions & 26 deletions moccasin/commands/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import requests # type: ignore
import tomli_w
from packaging.requirements import Requirement
from packaging.requirements import Requirement, InvalidRequirement
from tqdm import tqdm

from moccasin.config import get_config
Expand Down Expand Up @@ -54,36 +54,24 @@ def main(args: Namespace):


def classify_dependency(dependency: str) -> DependencyType:
dependency = dependency.strip().strip("\"'")
dependency = dependency.strip().strip("'\"")

# GitHub patterns
github_shorthand = r"^([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(@[a-zA-Z0-9_.-]+)?$"
github_url = r"^(git\+)?(https?://github\.com/[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)(\.git)?(@[a-zA-Z0-9_.-]+)?$"
github_url = r"^https://github\.com/[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$"

if re.match(github_shorthand, dependency) or re.match(github_url, dependency):
return DependencyType.GITHUB

return DependencyType.PIP


def extract_org_and_package(path: str) -> tuple[str, str]:
if "@" in path:
path, _ = path.split("@", 1)
path = path.strip().strip("'\"")
github_url_pattern = r"(?:https?://github\.com/|git\+https?://github\.com/)([\w-]+)/([\w-]+)(?:\.git)?"

# Check if it's a full GitHub URL
match = re.match(github_url_pattern, path)
if match:
return match.group(1), match.group(2)

# If it's not a full URL, treat it as org/package
parts = path.split("/")
if len(parts) >= 2:
return parts[0], parts[1]

# If we can't parse it, raise an exception
raise ValueError(f"Unable to extract organization and package from '{path}'")
def preprocess_requirement(package: str) -> str:
package = package.strip().strip("'\"")
git_url_pattern = r"^git\+https?://.*"
if re.match(git_url_pattern, package):
package_name = package.split("/")[-1].replace(".git", "")
return package_name
return package


# Much of this code thanks to brownie
Expand All @@ -99,7 +87,7 @@ def _github_installs(
else:
path = package_id
version = None # We'll fetch the latest version later
org, repo = extract_org_and_package(path)
org, repo = path.split("/")
except ValueError:
raise ValueError(
"Invalid package ID. Must be given as ORG/REPO[@VERSION]"
Expand Down Expand Up @@ -270,15 +258,22 @@ def _write_dependencies(new_package_ids: list[str], dependency_type: DependencyT
config = get_config()
dependencies = config.get_dependencies()
typed_dependencies = [
dep for dep in dependencies if classify_dependency(dep) == dependency_type
preprocess_requirement(dep)
for dep in dependencies
if classify_dependency(dep) == dependency_type
]

to_delete = set()
updated_packages = set()

if dependency_type == DependencyType.PIP:
for package in new_package_ids:
package_req = Requirement(package)
try:
processed_package = preprocess_requirement(package)
package_req = Requirement(processed_package)
except InvalidRequirement:
logger.warning(f"Invalid requirement format for package: {package}")
continue
for dep in typed_dependencies:
dep_req = Requirement(dep)
if dep_req.name == package_req.name:
Expand Down Expand Up @@ -329,7 +324,7 @@ def from_string(cls, dep_string: str) -> "GitHubDependency":
else:
path, version = dep_string, None

org, repo = extract_org_and_package(path)
org, repo = str(path).split("/")
return cls(org, repo, version)

def format_no_version(self) -> str:
Expand Down
41 changes: 9 additions & 32 deletions tests/unit/test_unit_install.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from moccasin.commands.install import (
DependencyType,
classify_dependency,
extract_org_and_package,
preprocess_requirement,
)


Expand All @@ -15,6 +15,11 @@ def test_classify_dependency_pip_version():
assert classify_dependency(dep) == DependencyType.PIP


def test_classify_dependency_pip_long():
dep = '"git+https://github.com/pcaversaccio/snekmate.git"'
assert classify_dependency(dep) == DependencyType.PIP


def test_classify_dependency_git():
dep = "pcaversaccio/snekmate"
assert classify_dependency(dep) == DependencyType.GITHUB
Expand All @@ -25,39 +30,11 @@ def test_classify_dependency_git_version():
assert classify_dependency(dep) == DependencyType.GITHUB


def test_classify_dependency_git_long():
dep = '"git+https://github.com/pcaversaccio/snekmate.git"'
assert classify_dependency(dep) == DependencyType.GITHUB


def test_classify_dependency_git_no_git():
dep = "https://github.com/pcaversaccio/snekmate"
assert classify_dependency(dep) == DependencyType.GITHUB


def test_extra_from_github():
dep = '"git+https://github.com/pcaversaccio/snekmate.git"'
org, package = extract_org_and_package(dep)
assert org == "pcaversaccio"
assert package == "snekmate"


def test_extra_from_github_shorthand():
dep = "pcaversaccio/snekmate"
org, package = extract_org_and_package(dep)
assert org == "pcaversaccio"
assert package == "snekmate"


def test_extra_from_github_version():
dep = "pcaversaccio/[email protected]"
org, package = extract_org_and_package(dep)
assert org == "pcaversaccio"
assert package == "snekmate"


def test_extra_from_github_no_git():
dep = "https://github.com/pcaversaccio/snekmate"
org, package = extract_org_and_package(dep)
assert org == "pcaversaccio"
assert package == "snekmate"
def test_preprocess_requirement():
req = '"git+https://github.com/pcaversaccio/snekmate.git"'
assert preprocess_requirement(req) == "snekmate"

0 comments on commit eb3a75a

Please sign in to comment.