Start to implement devices enrollment

This commit is contained in:
Pierre HUBERT 2024-07-01 21:10:45 +02:00
parent 378c296e71
commit 9ba4aa5194
21 changed files with 267 additions and 16 deletions

View File

@ -598,9 +598,11 @@ dependencies = [
"openssl-sys", "openssl-sys",
"rand", "rand",
"reqwest", "reqwest",
"semver",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror",
"uuid",
] ]
[[package]] [[package]]
@ -1828,6 +1830,9 @@ name = "semver"
version = "1.0.23" version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde" name = "serde"
@ -2238,6 +2243,16 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439"
dependencies = [
"getrandom",
"serde",
]
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.15" version = "0.2.15"

View File

@ -27,3 +27,5 @@ actix-session = { version = "0.9.0", features = ["cookie-session"] }
actix-cors = "0.7.0" actix-cors = "0.7.0"
actix-remote-ip = "0.1.0" actix-remote-ip = "0.1.0"
futures-util = "0.3.30" futures-util = "0.3.30"
uuid = { version = "1.9.1", features = ["v4", "serde"] }
semver = { version = "1.0.23", features = ["serde"] }

View File

@ -1,3 +1,4 @@
use crate::devices::device::DeviceId;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@ -154,7 +155,7 @@ impl AppConfig {
/// Get PKI root CA cert path /// Get PKI root CA cert path
pub fn root_ca_cert_path(&self) -> PathBuf { pub fn root_ca_cert_path(&self) -> PathBuf {
self.pki_path().join("root_ca.pem") self.pki_path().join("root_ca.crt")
} }
/// Get PKI root CA CRL path /// Get PKI root CA CRL path
@ -169,7 +170,7 @@ impl AppConfig {
/// Get PKI web CA cert path /// Get PKI web CA cert path
pub fn web_ca_cert_path(&self) -> PathBuf { pub fn web_ca_cert_path(&self) -> PathBuf {
self.pki_path().join("web_ca.pem") self.pki_path().join("web_ca.crt")
} }
/// Get PKI web CA CRL path /// Get PKI web CA CRL path
@ -184,7 +185,7 @@ impl AppConfig {
/// Get PKI devices CA cert path /// Get PKI devices CA cert path
pub fn devices_ca_cert_path(&self) -> PathBuf { pub fn devices_ca_cert_path(&self) -> PathBuf {
self.pki_path().join("devices_ca.pem") self.pki_path().join("devices_ca.crt")
} }
/// Get PKI devices CA CRL path /// Get PKI devices CA CRL path
@ -199,13 +200,33 @@ impl AppConfig {
/// Get PKI server cert path /// Get PKI server cert path
pub fn server_cert_path(&self) -> PathBuf { pub fn server_cert_path(&self) -> PathBuf {
self.pki_path().join("server.pem") self.pki_path().join("server.crt")
} }
/// Get PKI server private key path /// Get PKI server private key path
pub fn server_priv_key_path(&self) -> PathBuf { pub fn server_priv_key_path(&self) -> PathBuf {
self.pki_path().join("server.key") self.pki_path().join("server.key")
} }
/// Get devices configuration storage path
pub fn devices_config_path(&self) -> PathBuf {
self.storage_path().join("devices")
}
/// Get device configuration path
pub fn device_config_path(&self, id: &DeviceId) -> PathBuf {
self.devices_config_path().join(format!("{}.conf", id.0))
}
/// Get device certificate path
pub fn device_cert_path(&self, id: &DeviceId) -> PathBuf {
self.devices_config_path().join(format!("{}.crt", id.0))
}
/// Get device CSR path
pub fn device_csr_path(&self, id: &DeviceId) -> PathBuf {
self.devices_config_path().join(format!("{}.csr", id.0))
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -0,0 +1,52 @@
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DeviceInfo {
reference: String,
version: semver::Version,
max_relays: usize,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq, Hash)]
pub struct DeviceId(pub String);
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Device {
/// The device ID
id: DeviceId,
/// Information about the device
device: DeviceInfo,
/// Name given to the device on the Web UI
name: String,
/// Description given to the device on the Web UI
description: String,
/// Specify whether the device is enabled or not
enabled: bool,
/// Specify whether the device has been validated or not
validated: bool,
/// Information about the relays handled by the device
relays: Vec<DeviceRelay>,
}
/// Structure that contains information about the minimal expected execution
/// time of a device
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DailyMinRuntime {
min_runtime: usize,
reset_time: usize,
catch_up_hours: Vec<usize>,
}
#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
pub struct DeviceRelayID(uuid::Uuid);
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DeviceRelay {
id: DeviceRelayID,
name: String,
enabled: bool,
priority: usize,
consumption: usize,
minimal_uptime: usize,
minimal_downtime: usize,
daily_runtime: Option<DailyMinRuntime>,
depends_on: Vec<DeviceRelay>,
}

View File

@ -0,0 +1,36 @@
use crate::app_config::AppConfig;
use crate::devices::device::{Device, DeviceId};
use std::collections::HashMap;
pub struct DevicesList(HashMap<DeviceId, Device>);
impl DevicesList {
/// Load the list of devices. This method should be called only once during the whole execution
/// of the program
pub fn load() -> anyhow::Result<Self> {
let mut list = Self(HashMap::new());
for f in std::fs::read_dir(AppConfig::get().devices_config_path())? {
let f = f?.file_name();
let f = f.to_string_lossy();
let dev_id = match f.strip_suffix(".conf") {
Some(s) => DeviceId(s.to_string()),
// This is not a device configuration file
None => continue,
};
let device_conf = std::fs::read(AppConfig::get().device_config_path(&dev_id))?;
list.0.insert(dev_id, serde_json::from_slice(&device_conf)?);
}
Ok(list)
}
/// Check if a device with a given id exists or not
pub fn exists(&self, id: &DeviceId) -> bool {
self.0.contains_key(id)
}
}

View File

@ -0,0 +1,2 @@
pub mod device;
pub mod devices_list;

View File

@ -1,16 +1,20 @@
use crate::constants; use crate::constants;
use crate::devices::device::DeviceId;
use crate::devices::devices_list::DevicesList;
use crate::energy::consumption; use crate::energy::consumption;
use crate::energy::consumption::EnergyConsumption; use crate::energy::consumption::EnergyConsumption;
use actix::prelude::*; use actix::prelude::*;
pub struct EnergyActor { pub struct EnergyActor {
curr_consumption: EnergyConsumption, curr_consumption: EnergyConsumption,
devices: DevicesList,
} }
impl EnergyActor { impl EnergyActor {
pub async fn new() -> anyhow::Result<Self> { pub async fn new() -> anyhow::Result<Self> {
Ok(Self { Ok(Self {
curr_consumption: consumption::get_curr_consumption().await?, curr_consumption: consumption::get_curr_consumption().await?,
devices: DevicesList::load()?,
}) })
} }
@ -62,3 +66,16 @@ impl Handler<GetCurrConsumption> for EnergyActor {
self.curr_consumption self.curr_consumption
} }
} }
/// Get current consumption
#[derive(Message)]
#[rtype(result = "bool")]
pub struct CheckDeviceExists(DeviceId);
impl Handler<CheckDeviceExists> for EnergyActor {
type Result = bool;
fn handle(&mut self, msg: CheckDeviceExists, _ctx: &mut Context<Self>) -> Self::Result {
self.devices.exists(&msg.0)
}
}

View File

@ -1,6 +1,7 @@
pub mod app_config; pub mod app_config;
pub mod constants; pub mod constants;
pub mod crypto; pub mod crypto;
pub mod devices;
pub mod energy; pub mod energy;
pub mod server; pub mod server;
pub mod utils; pub mod utils;

View File

@ -15,6 +15,7 @@ async fn main() -> std::io::Result<()> {
// Initialize storage // Initialize storage
create_directory_if_missing(AppConfig::get().pki_path()).unwrap(); create_directory_if_missing(AppConfig::get().pki_path()).unwrap();
create_directory_if_missing(AppConfig::get().devices_config_path()).unwrap();
// Initialize PKI // Initialize PKI
pki::initialize_root_ca().expect("Failed to initialize Root CA!"); pki::initialize_root_ca().expect("Failed to initialize Root CA!");

View File

@ -103,6 +103,12 @@ impl From<actix_identity::error::LoginError> for HttpErr {
} }
} }
impl From<openssl::error::ErrorStack> for HttpErr {
fn from(value: openssl::error::ErrorStack) -> Self {
HttpErr::Err(std::io::Error::new(ErrorKind::Other, value.to_string()).into())
}
}
impl From<HttpResponse> for HttpErr { impl From<HttpResponse> for HttpErr {
fn from(value: HttpResponse) -> Self { fn from(value: HttpResponse) -> Self {
HttpErr::HTTPResponse(value) HttpErr::HTTPResponse(value)

View File

@ -0,0 +1,32 @@
use crate::devices::device::DeviceInfo;
use crate::server::custom_error::HttpResult;
use actix_web::{web, HttpResponse};
use openssl::x509::X509Req;
#[derive(Debug, serde::Deserialize)]
pub struct EnrollRequest {
/// Device CSR
csr: String,
/// Associated device information
info: DeviceInfo,
}
/// Enroll a new device
pub async fn enroll(req: web::Json<EnrollRequest>) -> HttpResult {
let csr = match X509Req::from_pem(req.csr.as_bytes()) {
Ok(r) => r,
Err(e) => {
log::error!("Failed to parse given CSR! {e}");
return Ok(HttpResponse::BadRequest().json("Failed to parse given CSR!"));
}
};
if !csr.verify(csr.public_key()?.as_ref())? {
log::error!("Invalid CSR signature!");
return Ok(HttpResponse::BadRequest().json("Could not verify CSR signature!"));
}
println!("{:#?}", &req);
Ok(HttpResponse::Ok().json("go on"))
}

View File

@ -1 +1,2 @@
pub mod mgmt_controller;
pub mod utils_controller; pub mod utils_controller;

View File

@ -3,7 +3,7 @@ use crate::constants;
use crate::crypto::pki; use crate::crypto::pki;
use crate::energy::energy_actor::EnergyActorAddr; use crate::energy::energy_actor::EnergyActorAddr;
use crate::server::auth_middleware::AuthChecker; use crate::server::auth_middleware::AuthChecker;
use crate::server::devices_api::utils_controller; use crate::server::devices_api::{mgmt_controller, utils_controller};
use crate::server::unsecure_server::*; use crate::server::unsecure_server::*;
use crate::server::web_api::*; use crate::server::web_api::*;
use actix_cors::Cors; use actix_cors::Cors;
@ -136,6 +136,10 @@ pub async fn secure_server(energy_actor: EnergyActorAddr) -> anyhow::Result<()>
"/devices_api/utils/time", "/devices_api/utils/time",
web::get().to(utils_controller::curr_time), web::get().to(utils_controller::curr_time),
) )
.route(
"/devices_api/mgmt/enroll",
web::post().to(mgmt_controller::enroll),
)
}) })
.bind_openssl(&AppConfig::get().listen_address, builder)? .bind_openssl(&AppConfig::get().listen_address, builder)?
.run() .run()

View File

@ -12,7 +12,7 @@ pub async fn serve_pki_file(path: web::Path<ServeCRLPath>) -> HttpResult {
for f in std::fs::read_dir(AppConfig::get().pki_path())? { for f in std::fs::read_dir(AppConfig::get().pki_path())? {
let f = f?; let f = f?;
let file_name = f.file_name().to_string_lossy().to_string(); let file_name = f.file_name().to_string_lossy().to_string();
if !file_name.ends_with(".crl") && !file_name.ends_with(".pem") { if !file_name.ends_with(".crl") && !file_name.ends_with(".crt") {
continue; continue;
} }

View File

@ -1,5 +1,11 @@
# Python client # Python client
Reformat code:
```bash
black src/*.py
```
Run the client: Run the client:
```bash ```bash

View File

@ -1,5 +1,7 @@
import requests import requests
from src.args import args from src.args import args
import src.constants as constants
def get_secure_origin() -> str: def get_secure_origin() -> str:
res = requests.get(f"{args.unsecure_origin}/secure_origin") res = requests.get(f"{args.unsecure_origin}/secure_origin")
@ -7,8 +9,32 @@ def get_secure_origin() -> str:
raise Exception(f"Get secure origin failed with status {res.status_code}") raise Exception(f"Get secure origin failed with status {res.status_code}")
return res.text return res.text
def get_root_ca() -> str: def get_root_ca() -> str:
res = requests.get(f"{args.unsecure_origin}/pki/root_ca.pem") res = requests.get(f"{args.unsecure_origin}/pki/root_ca.crt")
if res.status_code < 200 or res.status_code > 299: if res.status_code < 200 or res.status_code > 299:
raise Exception(f"Get root CA failed with status {res.status_code}") raise Exception(f"Get root CA failed with status {res.status_code}")
return res.text return res.text
def device_info():
"""
Get device information to return with enrollment and sync requests
"""
return {
"reference": constants.DEV_REFERENCE,
"version": constants.DEV_VERSION,
"max_relays": len(args.relay_gpios_list),
}
def enroll_device(csr: str) -> str:
res = requests.post(
f"{args.secure_origin}/devices_api/mgmt/enroll",
json={"csr": csr, "info": device_info()},
verify=args.root_ca_path,
)
if res.status_code < 200 or res.status_code > 299:
print(res.text)
raise Exception(f"Enrollment failed with status {res.status_code}")
return res.text

View File

@ -1,15 +1,25 @@
import argparse import argparse
import os import os
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="SolarEnergy Python-based client")
description='SolarEnergy Python-based client')
parser.add_argument("--unsecure_origin", help="Change unsecure API origin", default="http://localhost:8080") parser.add_argument(
"--unsecure_origin",
help="Change unsecure API origin",
default="http://localhost:8080",
)
parser.add_argument("--storage", help="Change storage location", default="storage") parser.add_argument("--storage", help="Change storage location", default="storage")
parser.add_argument(
"--relay_gpios",
help="Comma-separated list of GPIO used to modify relays",
default="5,6,7",
)
args = parser.parse_args() args = parser.parse_args()
args.secure_origin_path = os.path.join(args.storage, "SECURE_ORIGIN") args.secure_origin_path = os.path.join(args.storage, "SECURE_ORIGIN")
args.root_ca_path = os.path.join(args.storage, "root_ca.pem") args.root_ca_path = os.path.join(args.storage, "root_ca.crt")
args.dev_priv_key_path = os.path.join(args.storage, "dev.key") args.dev_priv_key_path = os.path.join(args.storage, "dev.key")
args.dev_csr_path = os.path.join(args.storage, "dev.csr") args.dev_csr_path = os.path.join(args.storage, "dev.csr")
args.dev_crt_path = os.path.join(args.storage, "dev.crt")
args.relay_gpios_list = list(map(lambda x: int(x), args.relay_gpios.split(",")))

View File

@ -0,0 +1,5 @@
# Device reference. This value should never be changed
DEV_REFERENCE = "PyDev"
# Current device version. Must follow semver semantic
DEV_VERSION = "0.0.1"

View File

@ -21,7 +21,6 @@ with open(args.secure_origin_path, "r") as f:
print(f"Secure origin = {args.secure_origin}") print(f"Secure origin = {args.secure_origin}")
print("Check system root CA") print("Check system root CA")
if not os.path.isfile(args.root_ca_path): if not os.path.isfile(args.root_ca_path):
origin = api.get_root_ca() origin = api.get_root_ca()
@ -43,3 +42,12 @@ if not os.path.isfile(args.dev_csr_path):
csr = pki.gen_csr(priv_key=priv_key, cn=f"PyDev {utils.rand_str(10)}") csr = pki.gen_csr(priv_key=priv_key, cn=f"PyDev {utils.rand_str(10)}")
with open(args.dev_csr_path, "w") as f: with open(args.dev_csr_path, "w") as f:
f.write(csr) f.write(csr)
print("Check device enrollment...")
if not os.path.isfile(args.dev_crt_path):
with open(args.dev_csr_path, "r") as f:
csr = "".join(f.read())
print("Enrolling device...")
crt = api.enroll_device(csr)
print("res" + crt)

View File

@ -1,13 +1,16 @@
from OpenSSL import crypto from OpenSSL import crypto
def gen_priv_key(): def gen_priv_key():
key = crypto.PKey() key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048) key.generate_key(crypto.TYPE_RSA, 2048)
return crypto.dump_privatekey(crypto.FILETYPE_PEM, key).decode("utf-8") return crypto.dump_privatekey(crypto.FILETYPE_PEM, key).decode("utf-8")
def parse_priv_key(priv_key: str) -> crypto.PKey: def parse_priv_key(priv_key: str) -> crypto.PKey:
return crypto.load_privatekey(crypto.FILETYPE_PEM, priv_key) return crypto.load_privatekey(crypto.FILETYPE_PEM, priv_key)
def gen_csr(priv_key: str, cn: str) -> str: def gen_csr(priv_key: str, cn: str) -> str:
priv_key = parse_priv_key(priv_key) priv_key = parse_priv_key(priv_key)

View File

@ -1,5 +1,8 @@
import string import string
import random import random
def rand_str(len: int) -> str: def rand_str(len: int) -> str:
return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(len)) return "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(len)
)