source: icGREP/icgrep-devel/icgrep/editd/EditdCudaDriver.h @ 5212

Last change on this file since 5212 was 5212, checked in by lindanl, 3 years ago

editd for GPU.

File size: 3.8 KB
Line 
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>

#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>

#include "cuda.h"

// Threads per block for the GPU_Main launch in RunPTX (blockSizeX). Also
// scales the per-stride input size and the output-buffer size computed there.
#define GROUPTHREADS 64
// Number of blocks in the GPU_Main launch grid in RunPTX (gridSizeX). Also
// scales the output-buffer size computed there.
#define GROUPBLOCKS 64

13void checkCudaErrors(CUresult err) {
14  assert(err == CUDA_SUCCESS);
15}
16
17/// main - Program entry point
18ulong * RunPTX(std::string PTXFilename, char * fileBuffer, ulong filesize, const char * patternStr, unsigned patternLen) {
19 
20  CUdevice    device;
21  CUmodule    cudaModule;
22  CUcontext   context;
23  CUfunction  function;
24  int         devCount;
25
26  // CUDA initialization
27  checkCudaErrors(cuInit(0));
28  checkCudaErrors(cuDeviceGetCount(&devCount));
29  checkCudaErrors(cuDeviceGet(&device, 0));
30
31  char name[128];
32  checkCudaErrors(cuDeviceGetName(name, 128, device));
33  // std::cout << "Using CUDA Device [0]: " << name << "\n";
34
35  int devMajor, devMinor;
36  checkCudaErrors(cuDeviceComputeCapability(&devMajor, &devMinor, device));
37  // std::cout << "Device Compute Capability: " << devMajor << "." << devMinor << "\n";
38  if (devMajor < 2) {
39    std::cerr << "ERROR: Device 0 is not SM 2.0 or greater\n";
40    exit(-1);
41  }
42
43  std::ifstream t(PTXFilename);
44  if (!t.is_open()) {
45    std::cerr << "Error: cannot open " << PTXFilename << " for processing. Skipped.\n";
46    exit(-1);
47  }
48 
49  std::string ptx_str((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
50
51  // Create driver context
52  checkCudaErrors(cuCtxCreate(&context, 0, device));
53
54  // Create module for object
55  checkCudaErrors(cuModuleLoadDataEx(&cudaModule, ptx_str.c_str(), 0, 0, 0));
56
57  // Get kernel function
58  checkCudaErrors(cuModuleGetFunction(&function, cudaModule, "GPU_Main"));
59
60  // Device data
61  CUdeviceptr devBufferInput;
62  CUdeviceptr devInputSize;
63  CUdeviceptr devPatterns;
64  CUdeviceptr devBufferOutput;
65  CUdeviceptr devStrides;
66
67  int strideSize = GROUPTHREADS * sizeof(ulong) * 4;
68  int strides = filesize/(strideSize * 2) + 1;
69  int bufferSize = strides * strideSize;
70  int outputSize = sizeof(ulong) * GROUPTHREADS * strides * 3 * GROUPBLOCKS;
71
72  checkCudaErrors(cuMemAlloc(&devBufferInput, bufferSize));
73  checkCudaErrors(cuMemAlloc(&devInputSize, sizeof(ulong)));
74  checkCudaErrors(cuMemAlloc(&devPatterns, patternLen));
75  checkCudaErrors(cuMemAlloc(&devBufferOutput, outputSize));
76  // checkCudaErrors(cuMemsetD8(devBufferOutput, 0, outputSize));
77  checkCudaErrors(cuMemAlloc(&devStrides, sizeof(int)));
78
79  //Copy from host to device
80  checkCudaErrors(cuMemcpyHtoD(devBufferInput, fileBuffer, bufferSize));
81  checkCudaErrors(cuMemcpyHtoD(devInputSize, &filesize, sizeof(ulong)));
82  checkCudaErrors(cuMemcpyHtoD(devPatterns, patternStr, patternLen));
83  checkCudaErrors(cuMemcpyHtoD(devStrides, &strides, sizeof(int)));
84
85  unsigned blockSizeX = GROUPTHREADS;
86  unsigned blockSizeY = 1;
87  unsigned blockSizeZ = 1;
88  unsigned gridSizeX  = GROUPBLOCKS;
89  unsigned gridSizeY  = 1;
90  unsigned gridSizeZ  = 1;
91
92  // Kernel parameters
93  void *KernelParams[] = { &devBufferInput, &devInputSize, &devPatterns, &devBufferOutput, &devStrides};
94
95  // std::cout << "Launching kernel\n";
96
97  // Kernel launch
98  checkCudaErrors(cuLaunchKernel(function, gridSizeX, gridSizeY, gridSizeZ,
99                                 blockSizeX, blockSizeY, blockSizeZ,
100                                 0, NULL, KernelParams, NULL));
101  // std::cout << "kernel success.\n";
102  // Retrieve device data
103
104  ulong * matchRslt = (ulong *) malloc(outputSize);
105  checkCudaErrors(cuMemcpyDtoH(matchRslt, devBufferOutput, outputSize));
106
107
108  // Clean-up
109  checkCudaErrors(cuMemFree(devBufferInput));
110  checkCudaErrors(cuMemFree(devInputSize));
111  checkCudaErrors(cuMemFree(devBufferOutput));
112  checkCudaErrors(cuMemFree(devPatterns));
113  checkCudaErrors(cuMemFree(devStrides));
114  checkCudaErrors(cuModuleUnload(cudaModule));
115  checkCudaErrors(cuCtxDestroy(context));
116
117  return matchRslt;
118}
Note: See TracBrowser for help on using the repository browser.