hydro_deploy/
azure.rs

1use std::any::Any;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::sync::{Arc, Mutex, OnceLock};
5
6use anyhow::Result;
7use async_trait::async_trait;
8use nanoid::nanoid;
9use serde_json::json;
10
11use super::terraform::{TERRAFORM_ALPHABET, TerraformOutput, TerraformProvider};
12use super::{ClientStrategy, Host, HostTargetType, LaunchedHost, ResourceBatch, ResourceResult};
13use crate::ssh::LaunchedSshHost;
14use crate::{BaseServerStrategy, HostStrategyGetter, PortNetworkHint};
15
/// Handle to an Azure virtual machine that has been provisioned by Terraform.
///
/// Implements [`LaunchedSshHost`] so the deployment machinery can connect to
/// the VM over SSH.
pub struct LaunchedVirtualMachine {
    // Shared Terraform results for the batch this VM was provisioned in
    // (includes the generated SSH key material).
    resource_result: Arc<ResourceResult>,
    // Login user for SSH sessions on the VM.
    user: String,
    /// Private IP address inside the VM's virtual network.
    pub internal_ip: String,
    /// Public IP address, if one was allocated for this VM.
    pub external_ip: Option<String>,
}
22
23impl LaunchedSshHost for LaunchedVirtualMachine {
24    fn get_external_ip(&self) -> Option<String> {
25        self.external_ip.clone()
26    }
27
28    fn get_internal_ip(&self) -> String {
29        self.internal_ip.clone()
30    }
31
32    fn get_cloud_provider(&self) -> String {
33        "Azure".to_string()
34    }
35
36    fn resource_result(&self) -> &Arc<ResourceResult> {
37        &self.resource_result
38    }
39
40    fn ssh_user(&self) -> &str {
41        self.user.as_str()
42    }
43}
44
/// A deployment [`Host`] backed by an Azure virtual machine, provisioned
/// through Terraform.
pub struct AzureHost {
    /// ID from [`crate::Deployment::add_host`].
    id: usize,

    // Azure resource group name shared by this host's resources.
    project: String,
    os_type: String, // linux or windows
    // Azure VM size, e.g. "Standard_B1s".
    machine_size: String,
    // Optional source image reference (publisher/offer/sku/version); a default
    // Ubuntu image is used when `None`.
    image: Option<HashMap<String, String>>,
    // Azure region/location, e.g. "East US".
    region: String,
    // SSH login user; defaults to "hydro" when `None`.
    user: Option<String>,
    // Set exactly once when the host is provisioned.
    pub launched: OnceLock<Arc<LaunchedVirtualMachine>>, // TODO(mingwei): fix pub
    // External TCP ports requested before launch; used for firewall setup.
    external_ports: Mutex<Vec<u16>>,
}
58
59impl AzureHost {
60    pub fn new(
61        id: usize,
62        project: String,
63        os_type: String, // linux or windows
64        machine_size: String,
65        image: Option<HashMap<String, String>>,
66        region: String,
67        user: Option<String>,
68    ) -> Self {
69        Self {
70            id,
71            project,
72            os_type,
73            machine_size,
74            image,
75            region,
76            user,
77            launched: OnceLock::new(),
78            external_ports: Mutex::new(Vec::new()),
79        }
80    }
81}
82
83impl Debug for AzureHost {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.write_fmt(format_args!("AzureHost({})", self.id))
86    }
87}
88
89#[async_trait]
90impl Host for AzureHost {
91    fn target_type(&self) -> HostTargetType {
92        HostTargetType::Linux
93    }
94
95    fn request_port_base(&self, bind_type: &BaseServerStrategy) {
96        match bind_type {
97            BaseServerStrategy::UnixSocket => {}
98            BaseServerStrategy::InternalTcpPort(_) => {}
99            BaseServerStrategy::ExternalTcpPort(port) => {
100                let mut external_ports = self.external_ports.lock().unwrap();
101                if !external_ports.contains(port) {
102                    if self.launched.get().is_some() {
103                        todo!("Cannot adjust firewall after host has been launched");
104                    }
105                    external_ports.push(*port);
106                }
107            }
108        }
109    }
110
111    fn request_custom_binary(&self) {
112        self.request_port_base(&BaseServerStrategy::ExternalTcpPort(22));
113    }
114
115    fn id(&self) -> usize {
116        self.id
117    }
118
119    fn collect_resources(&self, resource_batch: &mut ResourceBatch) {
120        if self.launched.get().is_some() {
121            return;
122        }
123
124        let project = self.project.as_str();
125
126        // first, we import the providers we need
127        resource_batch
128            .terraform
129            .terraform
130            .required_providers
131            .insert(
132                "azurerm".to_string(),
133                TerraformProvider {
134                    source: "hashicorp/azurerm".to_string(),
135                    version: "3.67.0".to_string(),
136                },
137            );
138
139        resource_batch
140            .terraform
141            .terraform
142            .required_providers
143            .insert(
144                "local".to_string(),
145                TerraformProvider {
146                    source: "hashicorp/local".to_string(),
147                    version: "2.3.0".to_string(),
148                },
149            );
150
151        resource_batch
152            .terraform
153            .terraform
154            .required_providers
155            .insert(
156                "tls".to_string(),
157                TerraformProvider {
158                    source: "hashicorp/tls".to_string(),
159                    version: "4.0.4".to_string(),
160                },
161            );
162
163        // we use a single SSH key for all VMs
164        resource_batch
165            .terraform
166            .resource
167            .entry("tls_private_key".to_string())
168            .or_default()
169            .insert(
170                "vm_instance_ssh_key".to_string(),
171                json!({
172                    "algorithm": "RSA",
173                    "rsa_bits": 4096
174                }),
175            );
176
177        resource_batch
178            .terraform
179            .resource
180            .entry("local_file".to_string())
181            .or_default()
182            .insert(
183                "vm_instance_ssh_key_pem".to_string(),
184                json!({
185                    "content": "${tls_private_key.vm_instance_ssh_key.private_key_pem}",
186                    "filename": ".ssh/vm_instance_ssh_key_pem",
187                    "file_permission": "0600"
188                }),
189            );
190
191        let vm_key = format!("vm-instance-{}", self.id);
192        let vm_name = format!("hydro-vm-instance-{}", nanoid!(8, &TERRAFORM_ALPHABET));
193
194        // Handle provider configuration
195        resource_batch.terraform.provider.insert(
196            "azurerm".to_string(),
197            json!({
198                "skip_provider_registration": "true",
199                "features": {},
200            }),
201        );
202
203        // Handle resources
204        resource_batch
205            .terraform
206            .resource
207            .entry("azurerm_resource_group".to_string())
208            .or_default()
209            .insert(
210                vm_key.to_string(),
211                json!({
212                    "name": project,
213                    "location": self.region.clone(),
214                }),
215            );
216
217        resource_batch
218            .terraform
219            .resource
220            .entry("azurerm_virtual_network".to_string())
221            .or_default()
222            .insert(
223                vm_key.to_string(),
224                json!({
225                    "name": format!("{vm_key}-network"),
226                    "address_space": ["10.0.0.0/16"],
227                    "location": self.region.clone(),
228                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}")
229                }),
230            );
231
232        resource_batch
233            .terraform
234            .resource
235            .entry("azurerm_subnet".to_string())
236            .or_default()
237            .insert(
238                vm_key.to_string(),
239                json!({
240                    "name": "internal",
241                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
242                    "virtual_network_name": format!("${{azurerm_virtual_network.{vm_key}.name}}"),
243                    "address_prefixes": ["10.0.2.0/24"]
244                }),
245            );
246
247        resource_batch
248            .terraform
249            .resource
250            .entry("azurerm_public_ip".to_string())
251            .or_default()
252            .insert(
253                vm_key.to_string(),
254                json!({
255                    "name": "hydropubip",
256                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
257                    "location": format!("${{azurerm_resource_group.{vm_key}.location}}"),
258                    "allocation_method": "Static",
259                }),
260            );
261
262        resource_batch
263            .terraform
264            .resource
265            .entry("azurerm_network_interface".to_string())
266            .or_default()
267            .insert(
268                vm_key.to_string(),
269                json!({
270                    "name": format!("{vm_key}-nic"),
271                    "location": format!("${{azurerm_resource_group.{vm_key}.location}}"),
272                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
273                    "ip_configuration": {
274                        "name": "internal",
275                        "subnet_id": format!("${{azurerm_subnet.{vm_key}.id}}"),
276                        "private_ip_address_allocation": "Dynamic",
277                        "public_ip_address_id": format!("${{azurerm_public_ip.{vm_key}.id}}"),
278                    }
279                }),
280            );
281
282        // Define network security rules - for now, accept all connections
283        resource_batch
284            .terraform
285            .resource
286            .entry("azurerm_network_security_group".to_string())
287            .or_default()
288            .insert(
289                vm_key.to_string(),
290                json!({
291                    "name": "primary_security_group",
292                    "location": format!("${{azurerm_resource_group.{vm_key}.location}}"),
293                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
294                }),
295            );
296
297        resource_batch
298            .terraform
299            .resource
300            .entry("azurerm_network_security_rule".to_string())
301            .or_default()
302            .insert(
303                vm_key.to_string(),
304                json!({
305                    "name": "allowall",
306                    "priority": 100,
307                    "direction": "Inbound",
308                    "access": "Allow",
309                    "protocol": "Tcp",
310                    "source_port_range": "*",
311                    "destination_port_range": "*",
312                    "source_address_prefix": "*",
313                    "destination_address_prefix": "*",
314                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
315                    "network_security_group_name": format!("${{azurerm_network_security_group.{vm_key}.name}}"),
316                })
317            );
318
319        resource_batch
320            .terraform
321            .resource
322            .entry("azurerm_subnet_network_security_group_association".to_string())
323            .or_default()
324            .insert(
325                vm_key.to_string(),
326                json!({
327                    "subnet_id": format!("${{azurerm_subnet.{vm_key}.id}}"),
328                    "network_security_group_id": format!("${{azurerm_network_security_group.{vm_key}.id}}"),
329                })
330            );
331
332        let user = self.user.as_ref().cloned().unwrap_or("hydro".to_string());
333        let os_type = format!("azurerm_{}_virtual_machine", self.os_type.clone());
334        let image = self.image.as_ref().cloned().unwrap_or(HashMap::from([
335            ("publisher".to_string(), "Canonical".to_string()),
336            (
337                "offer".to_string(),
338                "0001-com-ubuntu-server-jammy".to_string(),
339            ),
340            ("sku".to_string(), "22_04-lts".to_string()),
341            ("version".to_string(), "latest".to_string()),
342        ]));
343
344        resource_batch
345            .terraform
346            .resource
347            .entry(os_type.clone())
348            .or_default()
349            .insert(
350                vm_key.clone(),
351                json!({
352                    "name": vm_name,
353                    "resource_group_name": format!("${{azurerm_resource_group.{vm_key}.name}}"),
354                    "location": format!("${{azurerm_resource_group.{vm_key}.location}}"),
355                    "size": self.machine_size.clone(),
356                    "network_interface_ids": [format!("${{azurerm_network_interface.{vm_key}.id}}")],
357                    "admin_ssh_key": {
358                        "username": user,
359                        "public_key": "${tls_private_key.vm_instance_ssh_key.public_key_openssh}",
360                    },
361                    "admin_username": user,
362                    "os_disk": {
363                        "caching": "ReadWrite",
364                        "storage_account_type": "Standard_LRS",
365                    },
366                    "source_image_reference": image,
367                }),
368            );
369
370        resource_batch.terraform.output.insert(
371            format!("{vm_key}-public-ip"),
372            TerraformOutput {
373                value: format!("${{azurerm_public_ip.{vm_key}.ip_address}}"),
374            },
375        );
376
377        resource_batch.terraform.output.insert(
378            format!("{vm_key}-internal-ip"),
379            TerraformOutput {
380                value: format!("${{azurerm_network_interface.{vm_key}.private_ip_address}}"),
381            },
382        );
383    }
384
385    fn launched(&self) -> Option<Arc<dyn LaunchedHost>> {
386        self.launched
387            .get()
388            .map(|a| a.clone() as Arc<dyn LaunchedHost>)
389    }
390
391    fn provision(&self, resource_result: &Arc<ResourceResult>) -> Arc<dyn LaunchedHost> {
392        self.launched
393            .get_or_init(|| {
394                let id = self.id;
395
396                let internal_ip = resource_result
397                    .terraform
398                    .outputs
399                    .get(&format!("vm-instance-{id}-internal-ip"))
400                    .unwrap()
401                    .value
402                    .clone();
403
404                let external_ip = resource_result
405                    .terraform
406                    .outputs
407                    .get(&format!("vm-instance-{id}-public-ip"))
408                    .map(|v| v.value.clone());
409
410                Arc::new(LaunchedVirtualMachine {
411                    resource_result: resource_result.clone(),
412                    user: self.user.as_ref().cloned().unwrap_or("hydro".to_string()),
413                    internal_ip,
414                    external_ip,
415                })
416            })
417            .clone()
418    }
419
420    fn strategy_as_server<'a>(
421        &'a self,
422        client_host: &dyn Host,
423        network_hint: PortNetworkHint,
424    ) -> Result<(ClientStrategy<'a>, HostStrategyGetter)> {
425        if matches!(network_hint, PortNetworkHint::Auto)
426            && client_host.can_connect_to(ClientStrategy::UnixSocket(self.id))
427        {
428            Ok((
429                ClientStrategy::UnixSocket(self.id),
430                Box::new(|_| BaseServerStrategy::UnixSocket),
431            ))
432        } else if matches!(
433            network_hint,
434            PortNetworkHint::Auto | PortNetworkHint::TcpPort(_)
435        ) && client_host.can_connect_to(ClientStrategy::InternalTcpPort(self))
436        {
437            Ok((
438                ClientStrategy::InternalTcpPort(self),
439                Box::new(move |_| {
440                    BaseServerStrategy::InternalTcpPort(match network_hint {
441                        PortNetworkHint::Auto => None,
442                        PortNetworkHint::TcpPort(port) => port,
443                    })
444                }),
445            ))
446        } else if matches!(network_hint, PortNetworkHint::Auto)
447            && client_host.can_connect_to(ClientStrategy::ForwardedTcpPort(self))
448        {
449            Ok((
450                ClientStrategy::ForwardedTcpPort(self),
451                Box::new(|me| {
452                    me.downcast_ref::<AzureHost>()
453                        .unwrap()
454                        .request_port_base(&BaseServerStrategy::ExternalTcpPort(22)); // needed to forward
455                    BaseServerStrategy::InternalTcpPort(None)
456                }),
457            ))
458        } else {
459            anyhow::bail!("Could not find a strategy to connect to Azure instance")
460        }
461    }
462
463    fn can_connect_to(&self, typ: ClientStrategy) -> bool {
464        match typ {
465            ClientStrategy::UnixSocket(id) => {
466                #[cfg(unix)]
467                {
468                    self.id == id
469                }
470
471                #[cfg(not(unix))]
472                {
473                    let _ = id;
474                    false
475                }
476            }
477            ClientStrategy::InternalTcpPort(target_host) => {
478                if let Some(provider_target) = <dyn Any>::downcast_ref::<AzureHost>(target_host) {
479                    self.project == provider_target.project
480                } else {
481                    false
482                }
483            }
484            ClientStrategy::ForwardedTcpPort(_) => false,
485        }
486    }
487}