use axum::{ extract::{Path, Request, State}, http::StatusCode, middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, Json, Router, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; use uuid::Uuid; use crate::{ models::error::ApiError, services::host_provisioning::{HostProvisioningService, ProvisionedLicense}, AppState, }; pub fn router() -> Router> { Router::new() .route("/provision", post(provision_license)) .route("/licenses", get(list_host_licenses)) .route("/billing/:month", get(get_billing_report)) .layer(middleware::from_fn(host_auth_middleware)) } /// Host API key authentication middleware async fn host_auth_middleware( State(state): State>, mut req: Request, next: Next, ) -> Result { // Extract API key from Authorization header (Bearer token) let auth_header = req .headers() .get("Authorization") .and_then(|h| h.to_str().ok()) .ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?; if !auth_header.starts_with("Bearer ") { return Err(ApiError::Unauthorized("Invalid Authorization format. Use: Bearer ".to_string())); } let api_key = &auth_header[7..]; // Skip "Bearer " // Authenticate host let service = HostProvisioningService::new(state.db.clone()); let host_id = service .authenticate_host(api_key) .await .map_err(|_| ApiError::Unauthorized("Invalid or inactive API key".to_string()))?; // Store host_id in request extensions for handlers to access req.extensions_mut().insert(HostContext { host_id }); Ok(next.run(req).await) } /// Host context extracted from API key #[derive(Clone)] struct HostContext { host_id: Uuid, } // ============================================================================ // BULK LICENSE PROVISIONING // ============================================================================ #[derive(Deserialize)] struct ProvisionRequest { server_id: String, // Host's internal server identifier (e.g., "rust-nyc-01") hostname: Option, // Optional display name customer_email: String, // End customer email } #[derive(Serialize)] struct ProvisionResponse { license_key: String, companion_token: String, plugin_download_url: String, subdomain: String, panel_url: String, } /// Provision a new license for a hosting customer (B2B bulk provisioning) async fn provision_license( State(state): State>, axum::extract::Extension(ctx): axum::extract::Extension, Json(req): Json, ) -> Result { let service = HostProvisioningService::new(state.db.clone()); // Use hostname if provided, otherwise use server_id let server_identifier = req.hostname.as_ref().unwrap_or(&req.server_id); let provisioned = service .provision_license(ctx.host_id, server_identifier, &req.customer_email) .await?; Ok(Json(ProvisionResponse { license_key: provisioned.license_key.clone(), companion_token: provisioned.companion_token, plugin_download_url: provisioned.plugin_download_url, subdomain: provisioned.subdomain.clone(), panel_url: format!("https://panel.corrosionmgmt.com/login?license={}", provisioned.license_key), })) } // ============================================================================ // HOST LICENSE MANAGEMENT // ============================================================================ #[derive(Serialize)] struct HostLicenseInfo { license_key: String, server_name: String, customer_email: String, subdomain: String, active: bool, last_seen_at: Option>, provisioned_at: chrono::DateTime, } /// List all licenses provisioned by this host async fn list_host_licenses( State(state): State>, axum::extract::Extension(ctx): axum::extract::Extension, ) -> Result { let licenses = sqlx::query!( "SELECT l.license_key, l.server_name, l.subdomain, (l.status = 'active') as \"active!\", hl.customer_email, hl.last_seen_at, hl.provisioned_at FROM host_licenses hl INNER JOIN licenses l ON l.id = hl.license_id WHERE hl.host_id = $1 ORDER BY hl.provisioned_at DESC", ctx.host_id ) .fetch_all(&state.db) .await?; let result: Vec = licenses .into_iter() .map(|row| HostLicenseInfo { license_key: row.license_key, server_name: row.server_name.unwrap_or_default(), customer_email: row.customer_email.unwrap_or_default(), subdomain: row.subdomain.unwrap_or_default(), active: row.active, last_seen_at: row.last_seen_at, provisioned_at: row.provisioned_at.unwrap(), }) .collect(); Ok(Json(result)) } // ============================================================================ // BILLING REPORTS // ============================================================================ #[derive(Serialize)] struct BillingReport { month: String, active_license_count: i32, wholesale_rate_usd: rust_decimal::Decimal, total_amount_usd: rust_decimal::Decimal, licenses: Vec, } #[derive(Serialize)] struct BillingLicenseEntry { license_key: String, server_name: String, customer_email: String, active: bool, last_seen_at: Option>, } /// Get billing report for a specific month (format: YYYY-MM, e.g., "2026-02") async fn get_billing_report( State(state): State>, axum::extract::Extension(ctx): axum::extract::Extension, Path(month): axum::extract::Path, ) -> Result { // Parse month (format: YYYY-MM) let billing_month = chrono::NaiveDate::parse_from_str(&format!("{}-01", month), "%Y-%m-%d") .map_err(|_| ApiError::BadRequest("Invalid month format. Use YYYY-MM (e.g., 2026-02)".to_string()))?; // Get billing record let record = sqlx::query!( "SELECT active_license_count, wholesale_rate_usd, total_amount_usd FROM host_billing_records WHERE host_id = $1 AND billing_month = $2", ctx.host_id, billing_month ) .fetch_optional(&state.db) .await?; let (active_count, wholesale_rate, total_amount) = if let Some(rec) = record { ( rec.active_license_count, rec.wholesale_rate_usd, rec.total_amount_usd, ) } else { // No billing record yet — generate on-the-fly let service = HostProvisioningService::new(state.db.clone()); let count = service.get_active_license_count(ctx.host_id).await?; let rate = rust_decimal::Decimal::from(6); // Default $6/server let total = rate * rust_decimal::Decimal::from(count); (count as i32, rate, total) }; // Get license details let licenses = sqlx::query!( "SELECT l.license_key, l.server_name, (l.status = 'active') as \"active!\", hl.customer_email, hl.last_seen_at FROM host_licenses hl INNER JOIN licenses l ON l.id = hl.license_id WHERE hl.host_id = $1 ORDER BY l.server_name", ctx.host_id ) .fetch_all(&state.db) .await?; let license_entries: Vec = licenses .into_iter() .map(|row| BillingLicenseEntry { license_key: row.license_key, server_name: row.server_name.unwrap_or_default(), customer_email: row.customer_email.unwrap_or_default(), active: row.active, last_seen_at: row.last_seen_at, }) .collect(); Ok(Json(BillingReport { month: month.clone(), active_license_count: active_count, wholesale_rate_usd: wholesale_rate, total_amount_usd: total_amount, licenses: license_entries, })) }