diff --git a/central_backend/src/devices/devices_list.rs b/central_backend/src/devices/devices_list.rs index c884206..0cbaa65 100644 --- a/central_backend/src/devices/devices_list.rs +++ b/central_backend/src/devices/devices_list.rs @@ -101,6 +101,11 @@ impl DevicesList { self.0.clone().into_values().collect() } + /// Get the information about a single device + pub fn get_single(&self, id: &DeviceId) -> Option { + self.0.get(id).cloned() + } + /// Validate a device pub fn validate(&mut self, id: &DeviceId) -> anyhow::Result<()> { let dev = self diff --git a/central_backend/src/energy/energy_actor.rs b/central_backend/src/energy/energy_actor.rs index eb72d03..4c93672 100644 --- a/central_backend/src/energy/energy_actor.rs +++ b/central_backend/src/energy/energy_actor.rs @@ -137,3 +137,16 @@ impl Handler for EnergyActor { self.devices.full_list() } } + +/// Get the information about a single device +#[derive(Message)] +#[rtype(result = "Option")] +pub struct GetSingleDevice(pub DeviceId); + +impl Handler for EnergyActor { + type Result = Option; + + fn handle(&mut self, msg: GetSingleDevice, _ctx: &mut Context) -> Self::Result { + self.devices.get_single(&msg.0) + } +} diff --git a/central_backend/src/server/devices_api/mgmt_controller.rs b/central_backend/src/server/devices_api/mgmt_controller.rs index 648cbfe..f29092a 100644 --- a/central_backend/src/server/devices_api/mgmt_controller.rs +++ b/central_backend/src/server/devices_api/mgmt_controller.rs @@ -71,3 +71,34 @@ pub async fn enroll(req: web::Json, actor: WebEnergyActor) -> Htt Ok(HttpResponse::Accepted().json("Device successfully enrolled")) } + +#[derive(serde::Deserialize)] +pub struct EnrollmentStatusQuery { + id: DeviceId, +} + +#[derive(serde::Serialize)] +#[serde(tag = "status")] +enum EnrollmentDeviceStatus { + Unknown, + Pending, + Validated, +} + +/// Check device enrollment status +pub async fn enrollment_status( + query: web::Query, + actor: WebEnergyActor, +) -> HttpResult { + let dev = actor + .send(energy_actor::GetSingleDevice(query.id.clone())) + .await?; + + let status = match dev { + None => EnrollmentDeviceStatus::Unknown, + Some(d) if d.validated => EnrollmentDeviceStatus::Validated, + _ => EnrollmentDeviceStatus::Pending, + }; + + Ok(HttpResponse::Ok().json(status)) +} diff --git a/central_backend/src/server/servers.rs b/central_backend/src/server/servers.rs index ab2021c..515e3c0 100644 --- a/central_backend/src/server/servers.rs +++ b/central_backend/src/server/servers.rs @@ -156,7 +156,10 @@ pub async fn secure_server(energy_actor: EnergyActorAddr) -> anyhow::Result<()> "/devices_api/mgmt/enroll", web::post().to(mgmt_controller::enroll), ) - // TODO : check device status + .route( + "/devices_api/mgmt/enrollment_status", + web::get().to(mgmt_controller::enrollment_status), + ) }) .bind_openssl(&AppConfig::get().listen_address, builder)? .run() diff --git a/python_device/src/api.py b/python_device/src/api.py index 34370b9..80e6356 100644 --- a/python_device/src/api.py +++ b/python_device/src/api.py @@ -17,6 +17,20 @@ def get_root_ca() -> str: return res.text +def device_enrollment_status() -> str: + """ + Get current device enrollment status + """ + res = requests.get( + f"{args.secure_origin}/devices_api/mgmt/enrollment_status?id={args.dev_id}", + verify=args.root_ca_path, + ) + if res.status_code < 200 or res.status_code > 299: + print(res.text) + raise Exception(f"Failed to check enrollment with status {res.status_code}") + return res.json()["status"] + + def device_info(): """ Get device information to return with enrollment and sync requests diff --git a/python_device/src/args.py b/python_device/src/args.py index 2f6b277..579e920 100644 --- a/python_device/src/args.py +++ b/python_device/src/args.py @@ -19,8 +19,8 @@ args = parser.parse_args() args.secure_origin_path = os.path.join(args.storage, "SECURE_ORIGIN") args.root_ca_path = os.path.join(args.storage, "root_ca.crt") +args.dev_id_path = os.path.join(args.storage, "DEV_ID") args.dev_priv_key_path = os.path.join(args.storage, "dev.key") args.dev_csr_path = os.path.join(args.storage, "dev.csr") -args.dev_enroll_marker = os.path.join(args.storage, "ENROLL_SUBMITTED") args.dev_crt_path = os.path.join(args.storage, "dev.crt") args.relay_gpios_list = list(map(lambda x: int(x), args.relay_gpios.split(","))) diff --git a/python_device/src/main.py b/python_device/src/main.py index 9dd1c9a..0edf9e7 100644 --- a/python_device/src/main.py +++ b/python_device/src/main.py @@ -27,6 +27,15 @@ if not os.path.isfile(args.root_ca_path): with open(args.root_ca_path, "w") as f: f.write(origin) +print("Check device ID") +if not os.path.isfile(args.dev_id_path): + print("Generate device id...") + with open(args.dev_id_path, "w") as f: + f.write(f"PyDev {utils.rand_str(10)}") + +with open(args.dev_id_path, "r") as f: + args.dev_id = f.read() + print("Check private key") if not os.path.isfile(args.dev_priv_key_path): print("Generate private key...") @@ -39,19 +48,27 @@ if not os.path.isfile(args.dev_csr_path): print("Generate CSR...") with open(args.dev_priv_key_path, "r") as f: priv_key = "".join(f.readlines()) - csr = pki.gen_csr(priv_key=priv_key, cn=f"PyDev {utils.rand_str(10)}") + csr = pki.gen_csr(priv_key=priv_key, cn=args.dev_id) with open(args.dev_csr_path, "w") as f: f.write(csr) print("Check device enrollment...") -if not os.path.isfile(args.dev_enroll_marker): +status = api.device_enrollment_status() + +if status == "Unknown": + print("Device is unknown on the system, need to submit a CSR...") with open(args.dev_csr_path, "r") as f: csr = "".join(f.read()) print("Enrolling device...") crt = api.enroll_device(csr) - - with open(args.dev_enroll_marker, "w") as f: - f.write("submitted") + print("Done. Please accept the device on central system web UI") + exit(0) -# TODO : "intelligent" enrollment management (re-enroll if cancelled) \ No newline at end of file +if status == "Pending": + print( + "Device is enrolled, but not validated yet. Please accept the device on central system web UI" + ) + exit(0) + +print("Device is successfully enrolled!")