Skip to content

Commit

Permalink
Check the domains when posting, new route to force check_domain, get …
Browse files Browse the repository at this point in the history
…the check informations in GET
  • Loading branch information
Benjamin Bayart committed Jul 26, 2024
1 parent 29d1ab1 commit 9237e68
Show file tree
Hide file tree
Showing 16 changed files with 333 additions and 37 deletions.
66 changes: 66 additions & 0 deletions src/alembic/versions/a00d7feb5df9_add_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""add errors
Revision ID: a00d7feb5df9
Revises: 7d80d2028a7f
Create Date: 2024-07-26 10:13:59.099634
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'a00d7feb5df9'
down_revision: Union[str, None] = '7d80d2028a7f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade(engine_name: str) -> None:
globals()["upgrade_%s" % engine_name]()


def downgrade(engine_name: str) -> None:
globals()["downgrade_%s" % engine_name]()





def upgrade_api() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('domains', sa.Column('errors', sa.JSON(), nullable=True))
# ### end Alembic commands ###


def downgrade_api() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('domains', 'errors')
# ### end Alembic commands ###


def upgrade_dovecot() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade_dovecot() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def upgrade_postfix() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade_postfix() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

3 changes: 3 additions & 0 deletions src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ def _make_domain(
log.info("- creating the domain")
res = client.post(
"/domains/",
params = {
"no_check": "true",
},
json = {
"name": name,
"features": features,
Expand Down
4 changes: 3 additions & 1 deletion src/dns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .domain import Domain
from .domain import background_check_new_domain, foreground_check_domain, Domain
from .dkim import DkimInfo
from .utils import get_ip_address, make_auth_resolver

__all__ = [
background_check_new_domain,
DkimInfo,
Domain,
foreground_check_domain,
get_ip_address,
make_auth_resolver,
]
74 changes: 54 additions & 20 deletions src/dns/domain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import dns.name
import dns.resolver
import sqlalchemy.orm as orm
Expand All @@ -23,7 +25,7 @@
targets = {
"webmail": "webmail.ox.numerique.gouv.fr.",
"imap": "imap.ox.numerique.gouv.fr.",
"mailbox": "mail.ox.numerique.gouv.fr.",
"mail": "mail.ox.numerique.gouv.fr.",
"smtp": "smtp.ox.numerique.gouv.fr.",
}
#required_mx = "mx.fdn.fr."
Expand All @@ -48,8 +50,8 @@ def __init__(

self.dkim = dkim

def add_err(self, err: str, detail: str = ""):
self.errs.append({"code": err, "detail": detail})
def add_err(self, test: str, err: str, detail: str = ""):
self.errs.append({"test": test, "code": err, "detail": detail})
self.valid = False

def get_auth_resolver(self, domain: str, insist: bool = False) -> dns.resolver.Resolver:
Expand All @@ -64,7 +66,7 @@ def get_auth_resolver(self, domain: str, insist: bool = False) -> dns.resolver.R
def check_exists(self):
resolver = self.get_auth_resolver(self.domain.name)
if resolver is None:
self.add_err("must_exist", f"Le domaine {self.domain.name} n'existe pas")
self.add_err("domain_exist", "must_exist", f"Le domaine {self.domain.name} n'existe pas")
return

def try_cname_for_mx(self):
Expand All @@ -75,10 +77,11 @@ def try_cname_for_mx(self):
print(f"Je trouve un CNAME vers {self.dest_domain}, je le prend comme dest_domain")
self.dest_name = dns.name.from_text(self.dest_domain)
return self.check_mx()
except dns.resolver.NXDOMAIN:
self.add_err("no_mx", "Il n'y a pas d'enregistrement MX ou CNAME sur le domaine")
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
self.add_err("mx", "no_mx", "Il n'y a pas d'enregistrement MX ou CNAME sur le domaine")
return
except Exception:
except Exception as e:
print(f"Unexpected exception while searching for a CNAME for a MX : {e}")
raise

def check_mx(self):
Expand All @@ -87,20 +90,23 @@ def check_mx(self):
print(f"Je cherche un MX pour {self.dest_domain}")
answer = resolver.resolve(self.dest_name, rdtype = "MX")
except dns.resolver.NXDOMAIN:
self.add_err("no_mx", "Il n'y a pas d'enregistrement MX sur le domaine")
print("NXDOMAIN")
self.add_err("mx", "no_mx", "Il n'y a pas d'enregistrement MX sur le domaine")
return
except dns.resolver.NoAnswer:
return self.try_cname_for_mx()
except Exception:
except Exception as e:
print(f"Unexpected exception while searching for MX {e}")
raise

nb_mx = len(answer)
if nb_mx != 1 and False:
self.add_err("one_mx", f"Je veux un seul MX, et j'en trouve {nb_mx}")
self.add_err("mx", "one_mx", f"Je veux un seul MX, et j'en trouve {nb_mx}")
return
mx = str(answer[0].exchange)
if not mx == required_mx:
self.add_err(
"mx",
"wrong_mx",
f"Je veux que le MX du domaine soit {required_mx}, "
f"or je trouve {mx}"
Expand All @@ -124,18 +130,20 @@ def check_cname(self, name):
for origin in origins:
resolver = self.get_auth_resolver(origin)
if resolver is None:
self.add_err(f"no_cname_{name}", f"Il faut un CNAME {origin} qui renvoie vers {target}")
self.add_err(f"cname_{name}", f"no_cname_{name}", f"Il faut un CNAME {origin} qui renvoie vers {target}")
continue
try:
answer = resolver.resolve(origin, rdtype = "CNAME")
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
self.add_err(f"no_cname_{name}", f"Il n'y a pas de CNAME {origin} -> {target}")
self.add_err(f"cname_{name}", f"no_cname_{name}", f"Il n'y a pas de CNAME {origin} -> {target}")
continue
except Exception:
except Exception as e:
print(f"Unexpected exception when searching for a CNAME : {e}")
raise
got_target = str(answer[0].target)
if not got_target == target:
self.add_err(
f"cname_{name}",
f"wrong_cname_{name}",
f"Le CNAME pour {origin} n'est pas bon, "
f"il renvoie vers {got_target} et je veux {target}"
Expand All @@ -147,7 +155,8 @@ def check_spf(self):
answer = resolver.resolve(self.dest_name, rdtype="TXT")
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
answer = []
except Exception:
except Exception as e:
print(f"Unexpected exception when searching for the SPF record : {e}")
raise
found_spf = False
valid_spf = False
Expand All @@ -159,10 +168,10 @@ def check_spf(self):
valid_spf = True
return
if not found_spf:
self.add_err("no_spf", f"Il faut un SPF record, et il doit contenir {required_spf}")
self.add_err("spf", "no_spf", f"Il faut un SPF record, et il doit contenir {required_spf}")
return
if not valid_spf:
self.add_err("wrong_spf", f"Le SPF record ne contient pas {required_spf}")
self.add_err("spf", "wrong_spf", f"Le SPF record ne contient pas {required_spf}")
return

def check_dkim(self):
Expand All @@ -174,7 +183,8 @@ def check_dkim(self):
answer = resolver.resolve(self.dkim_name, rdtype="TXT")
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
answer = []
except Exception:
except Exception as e:
print(f"Unexpected exception when searching for the DKIM record : {e}")
raise
found_dkim = False
valid_dkim = False
Expand All @@ -188,10 +198,10 @@ def check_dkim(self):
valid_dkim = True
return
if not found_dkim:
self.add_err("no_dkim", "Il faut un DKIM record, et il doit contenir la bonne clef")
self.add_err("dkim", "no_dkim", "Il faut un DKIM record, et il doit contenir la bonne clef")
return
if not valid_dkim:
self.add_err("wrong_dkim", "Le DKIM record n'est pas valide (il ne contient pas la bonne clef)")
self.add_err("dkim", "wrong_dkim", "Le DKIM record n'est pas valide (il ne contient pas la bonne clef)")
return

def _check_domain(self) -> bool:
Expand All @@ -205,7 +215,6 @@ def _check_domain(self) -> bool:
self.check_cname("webmail")
if self.domain.has_feature("mailbox"):
self.check_cname("imap")
self.check_cname("mailbox")
self.check_cname("smtp")
self.check_spf()
self.check_dkim()
Expand Down Expand Up @@ -247,3 +256,28 @@ def db_setup(self, db: orm.Session):
dom = None
if dom is None:
dom = sql_postfix.create_alias_domain(db, self.domain.name, self.domain.get_alias_domain())


def foreground_check_domain(db: orm.Session, db_dom: sql_api.DBDomain) -> sql_api.DBDomain:
name = db_dom.name
ck_dom = Domain(db_dom)
ck_dom.check()
if ck_dom.valid:
sql_api.update_domain_state(db, name, "ok")
db_dom = sql_api.update_domain_errors(db, name, None)
else:
sql_api.update_domain_state(db, name, "broken")
db_dom = sql_api.update_domain_errors(db, name, ck_dom.errs)
return db_dom

def background_check_new_domain(name: str):
log = logging.getLogger(__name__)
maker = sql_api.get_maker()
db = maker()
db_dom = sql_api.get_domain(db, name)
if db_dom is None:
log.error("Je ne sais pas vérifier un domaine qui n'existe pas en base")
db.close()
return
db_dom = foreground_check_domain(db, db_dom)
db.close()
3 changes: 1 addition & 2 deletions src/dns/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def test_domain_check():
ck_dom = domain.Domain(db_dom)
ck_dom.check()
assert ck_dom.valid is False
assert len(ck_dom.errs) == 6
assert len(ck_dom.errs) == 5
codes = [ err["code"] for err in ck_dom.errs ]
assert "wrong_mx" in codes
assert "wrong_cname_webmail" in codes
assert "wrong_cname_imap" in codes
assert "no_cname_mailbox" in codes
assert "no_cname_smtp" in codes
assert "wrong_spf" in codes

Expand Down
2 changes: 2 additions & 0 deletions src/routes/domains/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# ruff: noqa: E402
from .check_domain import check_domain
from .get_domain import get_domain
from .get_domains import get_domains
from .post_domain import post_domain

__all__ = [
check_domain,
get_domain,
get_domains,
post_domain,
Expand Down
31 changes: 31 additions & 0 deletions src/routes/domains/check_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging

import fastapi

from ... import auth, dns, sql_api, web_models
from .. import dependencies, routers


@routers.domains.get("/{domain_name}/check")
async def check_domain(
db: dependencies.DependsApiDb,
user: auth.DependsBasicAdmin,
domain_name: str,
) -> web_models.Domain:
log = logging.getLogger(__name__)
perms = user.get_creds()

domain_db = sql_api.get_domain(db, domain_name)
if domain_db is None:
log.info(f"Domain {domain_name} not found.")
raise fastapi.HTTPException(status_code=404, detail="Domain not found")

if not perms.can_read(domain_name):
log.info(f"Permission denied on domain {domain_name} for user.")
raise fastapi.HTTPException(status_code=401, detail="Not authorized.")

domain_db = dns.foreground_check_domain(db, domain_db)
log.info(f"Domain state after check is {domain_db.state}")
assert domain_db.state in [ "ok", "broken" ]

return web_models.Domain.from_db(domain_db)
9 changes: 6 additions & 3 deletions src/routes/domains/post_domain.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import fastapi

from ... import auth, oxcli, sql_api, web_models
from ... import auth, dns, oxcli, sql_api, web_models
from .. import dependencies, routers


@routers.domains.post("/", status_code=201)
async def post_domain(
db: dependencies.DependsApiDb,
user: auth.DependsBasicAdmin,
domain: web_models.Domain,
domain: web_models.CreateDomain,
bg: fastapi.BackgroundTasks,
no_check: str = "false",
) -> web_models.Domain:
if "webmail" in domain.features and domain.context_name is None:
raise fastapi.HTTPException(
Expand Down Expand Up @@ -44,7 +46,8 @@ async def post_domain(
imap_domains=domain.imap_domains,
smtp_domains=domain.smtp_domains,
)

if no_check == "false":
bg.add_task(dns.background_check_new_domain, domain.name)
if "webmail" in domain.features:
return web_models.Domain.from_db(domain_db, domain.context_name)
else:
Expand Down
Loading

0 comments on commit 9237e68

Please sign in to comment.