You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airavata.apache.org by di...@apache.org on 2022/12/31 10:53:28 UTC
[airavata-mft] branch master updated: Auto deployment of EC2 transfer agents
This is an automated email from the ASF dual-hosted git repository.
dimuthuupe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-mft.git
The following commit(s) were added to refs/heads/master by this push:
new a1f5432 Auto deployment of EC2 transfer agents
a1f5432 is described below
commit a1f5432d2ed72da205aedcf210dc1cd83df06f13
Author: Dimuthu Wannipurage <di...@gmail.com>
AuthorDate: Sat Dec 31 05:53:05 2022 -0500
Auto deployment of EC2 transfer agents
---
.../airavata/mft/api/handler/MFTApiHandler.java | 1 +
.../mft/controller/AgentTransferDispatcher.java | 149 +++++++++++++++-----
.../apache/airavata/mft/controller/AppConfig.java | 4 +-
.../mft/controller/spawner/CloudAgentSpawner.java | 7 +-
.../{AwsAgentSpawner.java => EC2AgentSpawner.java} | 150 +++++++++++++++++----
.../mft/controller/spawner/SSHProvider.java | 134 ++++++++++++++++++
.../mft/controller/spawner/SpawnerSelector.java | 2 +-
python-cli/mft_cli/mft_cli/main.py | 3 +-
8 files changed, 389 insertions(+), 61 deletions(-)
diff --git a/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java b/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
index fd932c0..cc00835 100644
--- a/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
+++ b/api/service/src/main/java/org/apache/airavata/mft/api/handler/MFTApiHandler.java
@@ -265,6 +265,7 @@ public class MFTApiHandler extends MFTTransferServiceGrpc.MFTTransferServiceImpl
} else if (!mainTransferStatus.isEmpty()){
stateBuilder.setState(mainTransferStatus.get(0).getState());
stateBuilder.setPercentage(0);
+ stateBuilder.setDescription(mainTransferStatus.get(0).getDescription());
responseObserver.onNext(stateBuilder.build());
responseObserver.onCompleted();
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java b/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
index b5a5f1d..21cf35d 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/AgentTransferDispatcher.java
@@ -24,60 +24,134 @@ import org.apache.airavata.mft.api.service.EndpointPaths;
import org.apache.airavata.mft.api.service.TransferApiRequest;
import org.apache.airavata.mft.controller.spawner.CloudAgentSpawner;
import org.apache.airavata.mft.controller.spawner.SpawnerSelector;
+import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.Future;
+import java.util.concurrent.*;
import java.util.stream.Collectors;
public class AgentTransferDispatcher {
private static final Logger logger = LoggerFactory.getLogger(AgentTransferDispatcher.class);
+
+ //getId(transferRequest):Pair<TransferApiRequest, AgentTransferRequest.Builder>
+
private final Map<String, Pair<TransferApiRequest, AgentTransferRequest.Builder>> pendingTransferRequests = new ConcurrentHashMap<>();
+ private final Map<String, String> pendingTransferIds = new ConcurrentHashMap<>();
+ //getId(transferRequest):consulKey
+
private final Map<String, String> pendingTransferConsulKeys = new ConcurrentHashMap<>();
- private final Map<String, Future<String>> pendingAgentSpawners = new ConcurrentHashMap<>();
+
+ //getId(transferRequest):CloudAgentSpawner
+ private final Map<String, CloudAgentSpawner> pendingAgentSpawners = new ConcurrentHashMap<>();
+
+ // getId(transferRequest):Set(TransferId)
private final Map<String, Set<String>> runningAgentCache = new ConcurrentHashMap<>();
+ // AgentID:Spawner - Use this to keep track of agent spawners. This is required to terminate agent
+ private final Map<String, CloudAgentSpawner> agentSpawners = new ConcurrentHashMap<>();
+
+ private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+
+ // Temporarily store consul key until the optimizer spins up Agents. This will block the same pending transfer
+ // being handled twice
+ private final Set<String> optimizingConsulKeys = new ConcurrentSkipListSet<>();
+
@Autowired
private MFTConsulClient mftConsulClient;
+ public void init() {
+ scheduler.scheduleWithFixedDelay(() -> {
+ pendingAgentSpawners.forEach((key, spawner) -> {
+ if (spawner.getLaunchState().isDone()) {
+ String transferId = pendingTransferIds.get(key);
+ Pair<TransferApiRequest, AgentTransferRequest.Builder> transferRequests = pendingTransferRequests.get(key);
+ String consulKey = pendingTransferConsulKeys.get(key);
+
+ try {
+ String agentId = spawner.getLaunchState().get();
+ List<String> liveAgentIds = mftConsulClient.getLiveAgentIds();
+ if (liveAgentIds.stream().noneMatch(id -> id.equals(agentId))) {
+ throw new Exception("Agent was not registered even though the agent is up");
+ }
+
+ submitTransferToAgent(Collections.singletonList(agentId), transferId,
+ transferRequests.getLeft(), transferRequests.getRight(), consulKey);
+
+ // Use this to terminate agent in future
+ agentSpawners.put(agentId, spawner);
+
+ } catch (Exception e) {
+ logger.error("Failed to launch agent for key {}", key, e);
+ try {
+ mftConsulClient.saveTransferState(transferId, new TransferState()
+ .setUpdateTimeMils(System.currentTimeMillis())
+ .setState("FAILED").setPercentage(0)
+ .setPublisher("controller")
+ .setDescription("Failed to launch the agent. " + ExceptionUtils.getRootCauseMessage(e)));
+ } catch (Exception e2) {
+ logger.error("Failed to submit transfer fail error for transfer id {}", transferId, e2);
+ }
+
+ logger.info("Removing consul key {}", consulKey);
+ mftConsulClient.getKvClient().deleteKey(consulKey);
+ logger.info("Terminating the spawner");
+ spawner.terminate();
+
+ } finally {
+ pendingTransferIds.remove(key);
+ pendingTransferRequests.remove(key);
+ pendingAgentSpawners.remove(key);
+ pendingTransferConsulKeys.remove(key);
+ optimizingConsulKeys.remove(consulKey);
+ }
+ }
+ });
+ }, 3, 5, TimeUnit.SECONDS);
+ }
+
+
public void submitTransferToAgent(List<String> filteredAgents, String transferId,
TransferApiRequest transferRequest,
AgentTransferRequest.Builder agentTransferRequestTemplate, String consulKey)
throws Exception {
- if (filteredAgents.isEmpty()) {
+ try {
+ if (filteredAgents.isEmpty()) {
+ mftConsulClient.saveTransferState(transferId, new TransferState()
+ .setUpdateTimeMils(System.currentTimeMillis())
+ .setState("FAILED").setPercentage(0)
+ .setPublisher("controller")
+ .setDescription("No qualifying agent was found to orchestrate the transfer"));
+ return;
+ }
+
mftConsulClient.saveTransferState(transferId, new TransferState()
+ .setState("STARTING")
+ .setPercentage(0)
.setUpdateTimeMils(System.currentTimeMillis())
- .setState("FAILED").setPercentage(0)
.setPublisher("controller")
- .setDescription("No qualifying agent was found to orchestrate the transfer"));
- return;
- }
+ .setDescription("Initializing the transfer"));
- mftConsulClient.saveTransferState(transferId, new TransferState()
- .setState("STARTING")
- .setPercentage(0)
- .setUpdateTimeMils(System.currentTimeMillis())
- .setPublisher("controller")
- .setDescription("Initializing the transfer"));
+ AgentTransferRequest.Builder agentTransferRequest = agentTransferRequestTemplate.clone();
- AgentTransferRequest.Builder agentTransferRequest = agentTransferRequestTemplate.clone();
+ agentTransferRequest.setRequestId(UUID.randomUUID().toString());
+ for (EndpointPaths ep : transferRequest.getEndpointPathsList()) {
+ agentTransferRequest.addEndpointPaths(org.apache.airavata.mft.agent.stub.EndpointPaths.newBuilder()
+ .setSourcePath(ep.getSourcePath())
+ .setDestinationPath(ep.getDestinationPath()).buildPartial());
+ }
- agentTransferRequest.setRequestId(UUID.randomUUID().toString());
- for (EndpointPaths ep : transferRequest.getEndpointPathsList()) {
- agentTransferRequest.addEndpointPaths(org.apache.airavata.mft.agent.stub.EndpointPaths.newBuilder()
- .setSourcePath(ep.getSourcePath())
- .setDestinationPath(ep.getDestinationPath()).buildPartial());
+ // TODO use a better way to select the right agent
+ mftConsulClient.commandTransferToAgent(filteredAgents.get(0), transferId, agentTransferRequest.build());
+ mftConsulClient.markTransferAsProcessed(transferId, transferRequest);
+ logger.info("Marked transfer {} as processed", transferId);
+ } finally {
+ mftConsulClient.getKvClient().deleteKey(consulKey);
}
-
- // TODO use a better way to select the right agent
- mftConsulClient.commandTransferToAgent(filteredAgents.get(0), transferId, agentTransferRequest.build());
- mftConsulClient.markTransferAsProcessed(transferId, transferRequest);
- mftConsulClient.getKvClient().deleteKey(consulKey);
}
public void handleTransferRequest(String transferId,
@@ -85,6 +159,12 @@ public class AgentTransferDispatcher {
AgentTransferRequest.Builder agentTransferRequestTemplate,
String consulKey) throws Exception{
+ if (optimizingConsulKeys.contains(consulKey)) {
+ logger.info("Ignoring handling transfer id {} as it is already in optimizing stage", transferId);
+ return;
+ }
+
+ logger.info("Handling transfer id {} with consul key {}", transferId, consulKey);
List<String> liveAgentIds = mftConsulClient.getLiveAgentIds();
Map<String, Integer> targetAgentsMap = transferRequest.getTargetAgentsMap();
@@ -115,22 +195,27 @@ public class AgentTransferDispatcher {
if (sourceSpawner.isPresent()) {
logger.info("Launching {} spawner in source side for transfer {}",
sourceSpawner.get().getClass().getName(), transferId);
- Future<String> launchFuture = sourceSpawner.get().launch();
- pendingAgentSpawners.put(getId(transferRequest, true), launchFuture);
+
+ sourceSpawner.get().launch();
+ pendingAgentSpawners.put(getId(transferRequest, true), sourceSpawner.get());
pendingTransferRequests.put(getId(transferRequest, true),
Pair.of(transferRequest, agentTransferRequestTemplate));
+ pendingTransferIds.put(getId(transferRequest, true), transferId);
pendingTransferConsulKeys.put(getId(transferRequest, true), consulKey);
-
+ optimizingConsulKeys.add(consulKey);
+ return;
} else if (destSpawner.isPresent()) {
logger.info("Launching {} spawner in destination side for transfer {}",
destSpawner.get().getClass().getName(), transferId);
- Future<String> launchFuture = destSpawner.get().launch();
- pendingAgentSpawners.put(getId(transferRequest, false), launchFuture);
+ destSpawner.get().launch();
+ pendingAgentSpawners.put(getId(transferRequest, false), destSpawner.get());
pendingTransferRequests.put(getId(transferRequest, false),
Pair.of(transferRequest, agentTransferRequestTemplate));
+ pendingTransferIds.put(getId(transferRequest, false), transferId);
pendingTransferConsulKeys.put(getId(transferRequest, false), consulKey);
-
+ optimizingConsulKeys.add(consulKey);
+ return;
} else {
logger.warn("No optimizing path is available. Moving user provided agents");
submitTransferToAgent(userProvidedAgents, transferId,
@@ -165,8 +250,6 @@ public class AgentTransferDispatcher {
agentTransferRequestTemplate,
consulKey);
}
-
- logger.info("Marked transfer {} as processed", transferId);
}
private String getId(TransferApiRequest transferRequest, boolean isSource) {
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java b/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
index c962bbb..ca31060 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/AppConfig.java
@@ -67,6 +67,8 @@ public class AppConfig {
@Bean
public AgentTransferDispatcher pathOptimizer() {
- return new AgentTransferDispatcher();
+ AgentTransferDispatcher agentTransferDispatcher = new AgentTransferDispatcher();
+ agentTransferDispatcher.init();
+ return agentTransferDispatcher;
}
}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
index 8d73a73..8b435ab 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/CloudAgentSpawner.java
@@ -31,6 +31,9 @@ public abstract class CloudAgentSpawner {
this.storageWrapper = storageWrapper;
}
- public abstract Future<String> launch();
- public abstract Future<Boolean> terminate();
+ public abstract void launch();
+ public abstract Future<String> getLaunchState();
+ public abstract void terminate();
+
+ public abstract Future<Boolean> getTerminateState();
}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
similarity index 52%
rename from controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java
rename to controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
index c992e01..890e395 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AwsAgentSpawner.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
@@ -25,54 +25,50 @@ import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
import com.amazonaws.services.ec2.model.*;
import org.apache.airavata.mft.agent.stub.SecretWrapper;
import org.apache.airavata.mft.agent.stub.StorageWrapper;
-import org.apache.airavata.mft.controller.AgentTransferDispatcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
+import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Stream;
-public class AwsAgentSpawner extends CloudAgentSpawner {
+public class EC2AgentSpawner extends CloudAgentSpawner {
- private static final Logger logger = LoggerFactory.getLogger(AwsAgentSpawner.class);
+ private static final Logger logger = LoggerFactory.getLogger(EC2AgentSpawner.class);
private final ExecutorService executor = Executors.newSingleThreadExecutor();
- public AwsAgentSpawner(StorageWrapper storageWrapper, SecretWrapper secretWrapper) {
+ private String instanceId;
+ private CountDownLatch portForwardLock;
+ Future<String> launchFuture;
+ Future<Boolean> terminateFuture;
+
+ public EC2AgentSpawner(StorageWrapper storageWrapper, SecretWrapper secretWrapper) {
super(storageWrapper, secretWrapper);
}
@Override
- public Future<String> launch() {
+ public void launch() {
- return executor.submit( () -> {
- String imageId = "ami-0574da719dca65348";
+ launchFuture = executor.submit( () -> {
+ String imageId = "ami-0ecc74eca1d66d8a6"; // Ubuntu base image
String keyNamePrefix = "mft-aws-agent-key-";
- String region = storageWrapper.getS3().getRegion();
String secGroupName = "MFTAgentSecurityGroup";
String agentId = UUID.randomUUID().toString();
+ String systemUser = "ubuntu";
String mftKeyDir = System.getProperty("user.home") + File.separator + ".mft" + File.separator + "keys";
String accessKey = secretWrapper.getS3().getAccessKey();
String secretKey = secretWrapper.getS3().getSecretKey();
-
- String cloudInit =
- "#cloud-config\n" +
- "\n" +
- "runcmd:\n" +
- " - apt install -y openjdk-11-jre-headless\n" +
- " - apt install -y unzip\n" +
- " - wget https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip\n" +
- " - unzip MFT-Agent-0.01-bin.zip -d /home/ubuntu/\n" +
- " - sed -ir \"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" /home/ubuntu/MFT-Agent-0.01/conf/application.properties\n" +
- " - chown -R ubuntu:ubuntu /home/ubuntu/MFT-Agent-0.01/\n";
+ String region = storageWrapper.getS3().getRegion();
try {
BasicAWSCredentials awsCreds = new BasicAWSCredentials(accessKey, secretKey);
@@ -159,22 +155,130 @@ public class AwsAgentSpawner extends CloudAgentSpawner {
new TagSpecification().withResourceType(ResourceType.Instance)
.withTags(new Tag().withKey("Type").withValue("MFT-Agent"),
new Tag().withKey("AgentId").withValue(agentId)))
- .withUserData(Base64.getEncoder().encodeToString(cloudInit.getBytes()))
.withSecurityGroups(secGroupName);
logger.info("Launching the EC2 VM to start Agent {}", agentId);
RunInstancesResult result = amazonEC2.runInstances(runInstancesRequest);
+ instanceId = result.getReservation().getInstances().get(0).getInstanceId();
+
+ try {
+ DescribeInstancesRequest describeInstancesRequest = new DescribeInstancesRequest();
+ describeInstancesRequest.setInstanceIds(Collections.singletonList(instanceId));
+
+ InstanceState instanceState = null;
+ String publicIpAddress = null;
+
+ logger.info("Waiting until instance {} is ready", instanceId);
+
+ for (int i = 0; i < 30; i++) {
+ DescribeInstancesResult describeInstancesResult = amazonEC2.describeInstances(describeInstancesRequest);
+ Instance instance = describeInstancesResult.getReservations().get(0).getInstances().get(0);
+ instanceState = instance.getState();
+ publicIpAddress = instance.getPublicIpAddress();
+
+ logger.info("Instance state {}, public ip {}", instanceState.getName(), publicIpAddress);
+
+ if (instanceState.getName().equals("running") && publicIpAddress != null) {
+ break;
+ }
+ Thread.sleep(2000);
+ }
+
+ logger.info("Waiting 30 seconds until the ssh interface comes up in instance {}", instanceId);
+ Thread.sleep(30000);
+ if ("running".equals(instanceState.getName()) && publicIpAddress != null) {
+ SSHProvider portForwardAgent = new SSHProvider();
+ portForwardAgent.initConnection(publicIpAddress, 22,
+ Path.of(mftKeyDir, keyName).toAbsolutePath().toString(), systemUser);
+ logger.info("Created SSH Connection. Installing dependencies...");
+
+ int exeCode = portForwardAgent.runCommand("sudo apt install -y openjdk-11-jre-headless");
+ if (exeCode != 0)
+ throw new IOException("Failed to install jdk on new VM");
+ exeCode = portForwardAgent.runCommand("sudo apt install -y unzip");
+ if (exeCode != 0)
+ throw new IOException("Failed to install unzip on new VM");
+ exeCode = portForwardAgent.runCommand("wget https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
+ if (exeCode != 0)
+ throw new IOException("Failed to download mft distribution");
+ exeCode = portForwardAgent.runCommand("unzip MFT-Agent-0.01-bin.zip");
+ if (exeCode != 0)
+ throw new IOException("Failed to unzip mft distribution");
+
+ exeCode = portForwardAgent.runCommand("sed -ir \"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" /home/ubuntu/MFT-Agent-0.01/conf/application.properties");
+ if (exeCode != 0)
+ throw new IOException("Failed to update agent id in config file");
+
+ portForwardLock = new CountDownLatch(1);
+ CountDownLatch portForwardPendingLock = portForwardAgent.createSshPortForward(8500, portForwardLock);
+
+ logger.info("Waiting until the port forward is setup");
+ portForwardPendingLock.await();
+
+ exeCode = portForwardAgent.runCommand("sh MFT-Agent-0.01/bin/agent-daemon.sh start");
+ if (exeCode != 0)
+ throw new IOException("Failed to start the MFT Agent");
+
+ // Waiting 10 seconds to start the Agent
+ Thread.sleep(10000);
+
+ } else {
+ logger.info("Instance {} was not setup properly", instanceId);
+ throw new Exception("Instance " + instanceId + " was not setup properly");
+ }
+ } catch (Exception e) {
+ logger.error("Failed preparing instance {}. Deleting the instance", instanceId);
+ TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+ terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
+ amazonEC2.terminateInstances(terminateInstancesRequest);
+ throw e;
+ }
+
return agentId;
} catch (Exception e) {
- throw new RuntimeException("Failed to spin up the AWS Agent", e);
+ logger.error("Failed to spin up the EC2 Agent", e);
+ throw new RuntimeException("Failed to spin up the EC2 Agent", e);
}
});
}
@Override
- public Future<Boolean> terminate() {
- return null;
+ public void terminate() {
+
+ terminateFuture = executor.submit(() -> {
+ if (instanceId != null) {
+ String accessKey = secretWrapper.getS3().getAccessKey();
+ String secretKey = secretWrapper.getS3().getSecretKey();
+ String region = storageWrapper.getS3().getRegion();
+
+ BasicAWSCredentials awsCreds = new BasicAWSCredentials(accessKey, secretKey);
+
+ AmazonEC2 amazonEC2 = AmazonEC2ClientBuilder.standard().withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(
+ "https://ec2." + region + ".amazonaws.com", region))
+ .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
+ .build();
+
+ if (portForwardLock != null) {
+ portForwardLock.countDown();
+ }
+
+ TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+ terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
+ amazonEC2.terminateInstances(terminateInstancesRequest);
+ }
+ return true;
+ });
+ }
+
+ @Override
+ public Future<String> getLaunchState() {
+ return launchFuture;
+ }
+
+ @Override
+ public Future<Boolean> getTerminateState() {
+ return terminateFuture;
}
}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
new file mode 100644
index 0000000..bbd0166
--- /dev/null
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.airavata.mft.controller.spawner;
+
+import net.schmizz.keepalive.KeepAliveProvider;
+import net.schmizz.sshj.DefaultConfig;
+import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.connection.channel.direct.Session;
+import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
+import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
+import net.schmizz.sshj.transport.verification.HostKeyVerifier;
+import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
+import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
+import net.schmizz.sshj.userauth.method.AuthMethod;
+import net.schmizz.sshj.userauth.method.AuthPublickey;
+import net.schmizz.sshj.userauth.method.ChallengeResponseProvider;
+import net.schmizz.sshj.userauth.password.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.PublicKey;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
+public class SSHProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(SSHProvider.class);
+
+ private SSHClient client;
+
+ public void initConnection(String hostName, int port, String keyPath, String user) throws IOException {
+ DefaultConfig defaultConfig = new DefaultConfig();
+ defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
+
+ client = new SSHClient(defaultConfig);
+ client.addHostKeyVerifier(new HostKeyVerifier() {
+ @Override
+ public boolean verify(String s, int i, PublicKey publicKey) {
+ return true;
+ }
+
+ @Override
+ public List<String> findExistingAlgorithms(String s, int i) {
+ return null;
+ }
+ });
+
+ KeyProvider keyProvider = client.loadKeys(keyPath);
+
+ final List<AuthMethod> am = new LinkedList<>();
+ am.add(new AuthPublickey(keyProvider));
+
+ am.add(new AuthKeyboardInteractive(new ChallengeResponseProvider() {
+ @Override
+ public List<String> getSubmethods() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public void init(Resource resource, String name, String instruction) {
+ }
+
+ @Override
+ public char[] getResponse(String prompt, boolean echo) {
+ return new char[0];
+ }
+
+ @Override
+ public boolean shouldRetry() {
+ return false;
+ }
+ }));
+
+ client.connect(hostName, port);
+ client.auth(user, am);
+ }
+
+ public int runCommand(String command) throws IOException {
+ Session session = this.client.startSession();
+ logger.info("Running command {}", command);
+ Session.Command execResult = session.exec(command);
+ String stdOut = new String(execResult.getInputStream().readAllBytes());
+ String stdErr = new String(execResult.getErrorStream().readAllBytes());
+ logger.info("Std out: {}", stdOut);
+ logger.info("Std err: {}", stdErr);
+ logger.info("Exit code: {}", execResult.getExitStatus());
+ session.close();
+ return execResult.getExitStatus();
+ }
+
+ public CountDownLatch createSshPortForward(int localPort, CountDownLatch portForwardHoldLock) throws IOException, InterruptedException {
+
+ CountDownLatch portForwardCompleteLock = new CountDownLatch(1);
+ new Thread(() -> {
+ String consulHost = "localhost";
+
+ try {
+ client.getRemotePortForwarder().bind(
+ new RemotePortForwarder.Forward(localPort),
+ new SocketForwardingConnectListener(new InetSocketAddress(consulHost, localPort)));
+
+ portForwardCompleteLock.countDown();
+ logger.info("Created port forward to port " + localPort);
+ portForwardHoldLock.await();
+
+ logger.info("Releasing the remote port forward");
+ client.getRemotePortForwarder().cancel(new RemotePortForwarder.Forward(localPort));
+
+ } catch (Exception e) {
+ logger.error("Failed to create the remote port forward for port {}", localPort, e);
+ }
+ }).start();
+ return portForwardCompleteLock;
+ }
+}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
index 72cfad3..1c86446 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SpawnerSelector.java
@@ -28,7 +28,7 @@ public class SpawnerSelector {
switch (storageWrapper.getStorageCase()) {
case S3:
if (storageWrapper.getS3().getEndpoint().endsWith("amazonaws.com")) {
- return Optional.of(new AwsAgentSpawner(storageWrapper, secretWrapper));
+ return Optional.of(new EC2AgentSpawner(storageWrapper, secretWrapper));
}
break;
}
diff --git a/python-cli/mft_cli/mft_cli/main.py b/python-cli/mft_cli/mft_cli/main.py
index baf59d3..1accfd6 100644
--- a/python-cli/mft_cli/mft_cli/main.py
+++ b/python-cli/mft_cli/mft_cli/main.py
@@ -101,7 +101,8 @@ def copy(source, destination):
transfer_request = MFTTransferApi_pb2.TransferApiRequest(sourceStorageId = source_storage_id,
sourceSecretId = source_secret_id,
destinationStorageId = dest_storage_id,
- destinationSecretId = dest_secret_id)
+ destinationSecretId = dest_secret_id,
+ optimizeTransferPath = True)
if (source_metadata.WhichOneof('metadata') == 'directory') :
if (destination[-1] != "/"):