Start to implement devices enrollment

This commit is contained in:
2024-07-01 21:10:45 +02:00
parent 378c296e71
commit 9ba4aa5194
21 changed files with 267 additions and 16 deletions

View File

@ -1,5 +1,11 @@
# Python client
Reformat code:
```bash
black src/*.py
```
Run the client:
```bash

View File

@ -1,5 +1,7 @@
import requests
from src.args import args
import src.constants as constants
def get_secure_origin() -> str:
res = requests.get(f"{args.unsecure_origin}/secure_origin")
@ -7,8 +9,32 @@ def get_secure_origin() -> str:
raise Exception(f"Get secure origin failed with status {res.status_code}")
return res.text
def get_root_ca() -> str:
res = requests.get(f"{args.unsecure_origin}/pki/root_ca.pem")
res = requests.get(f"{args.unsecure_origin}/pki/root_ca.crt")
if res.status_code < 200 or res.status_code > 299:
raise Exception(f"Get root CA failed with status {res.status_code}")
return res.text
def device_info():
"""
Get device information to return with enrollment and sync requests
"""
return {
"reference": constants.DEV_REFERENCE,
"version": constants.DEV_VERSION,
"max_relays": len(args.relay_gpios_list),
}
def enroll_device(csr: str) -> str:
res = requests.post(
f"{args.secure_origin}/devices_api/mgmt/enroll",
json={"csr": csr, "info": device_info()},
verify=args.root_ca_path,
)
if res.status_code < 200 or res.status_code > 299:
print(res.text)
raise Exception(f"Enrollment failed with status {res.status_code}")
return res.text

View File

@ -1,15 +1,25 @@
import argparse
import os
parser = argparse.ArgumentParser(
description='SolarEnergy Python-based client')
parser = argparse.ArgumentParser(description="SolarEnergy Python-based client")
parser.add_argument("--unsecure_origin", help="Change unsecure API origin", default="http://localhost:8080")
parser.add_argument(
"--unsecure_origin",
help="Change unsecure API origin",
default="http://localhost:8080",
)
parser.add_argument("--storage", help="Change storage location", default="storage")
parser.add_argument(
"--relay_gpios",
help="Comma-separated list of GPIO used to modify relays",
default="5,6,7",
)
args = parser.parse_args()
args.secure_origin_path = os.path.join(args.storage, "SECURE_ORIGIN")
args.root_ca_path = os.path.join(args.storage, "root_ca.pem")
args.root_ca_path = os.path.join(args.storage, "root_ca.crt")
args.dev_priv_key_path = os.path.join(args.storage, "dev.key")
args.dev_csr_path = os.path.join(args.storage, "dev.csr")
args.dev_csr_path = os.path.join(args.storage, "dev.csr")
args.dev_crt_path = os.path.join(args.storage, "dev.crt")
args.relay_gpios_list = list(map(lambda x: int(x), args.relay_gpios.split(",")))

View File

@ -0,0 +1,5 @@
# Device reference. This value should never be changed
DEV_REFERENCE = "PyDev"
# Current device version. Must follow semver semantic
DEV_VERSION = "0.0.1"

View File

@ -21,7 +21,6 @@ with open(args.secure_origin_path, "r") as f:
print(f"Secure origin = {args.secure_origin}")
print("Check system root CA")
if not os.path.isfile(args.root_ca_path):
origin = api.get_root_ca()
@ -43,3 +42,12 @@ if not os.path.isfile(args.dev_csr_path):
csr = pki.gen_csr(priv_key=priv_key, cn=f"PyDev {utils.rand_str(10)}")
with open(args.dev_csr_path, "w") as f:
f.write(csr)
print("Check device enrollment...")
if not os.path.isfile(args.dev_crt_path):
with open(args.dev_csr_path, "r") as f:
csr = "".join(f.read())
print("Enrolling device...")
crt = api.enroll_device(csr)
print("res" + crt)

View File

@ -1,13 +1,16 @@
from OpenSSL import crypto
def gen_priv_key():
key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)
return crypto.dump_privatekey(crypto.FILETYPE_PEM, key).decode("utf-8")
def parse_priv_key(priv_key: str) -> crypto.PKey:
return crypto.load_privatekey(crypto.FILETYPE_PEM, priv_key)
def gen_csr(priv_key: str, cn: str) -> str:
priv_key = parse_priv_key(priv_key)
@ -15,5 +18,5 @@ def gen_csr(priv_key: str, cn: str) -> str:
req.get_subject().CN = cn
req.set_pubkey(priv_key)
req.sign(priv_key, "sha256")
return crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode("utf-8")

View File

@ -1,5 +1,8 @@
import string
import random
def rand_str(len: int) -> str:
return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(len))
return "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(len)
)