scala 소스 코드 리뷰

2015-05-14 12:23

최근 스터디에서 scala를 주제로 학습하고 있다. 단순히 scala 책만 보려니 재미 없어서 생각하는 프로그래밍 책을 읽으면서 알고리즘 공부도 좀 하면서 scala 언어도 배워가고 있다.

요구사항은 다음과 같다.

최대 40억 개의 32비트 정수가 랜덤한 순서로 들어있는 순차적 파일이 주어졌다. 이 때 이 파일에 포함되지 않은 임의의 32비트 정수 하나를 찾아라. (적어도 하나는 없을 것임이 분명하다. - 이유는?) 메모리는 수백 바이트 밖에 없고, 너댓 개의 외부 임시 파일을 사용할 수 있는 상황이라면 어떻게 해결할 수 있을까?

뭐 혼자 문제를 풀기는 힘들어 책의 힌트에서 아이디어를 얻어 scala로 구현을 시작했다. 일단 시작 단계는 동작하도록 구현하는데 집중한 코드는 다음과 같이 완전 지저분하다.

package pearls.essay2

import scala.io.Source
import pearls.support.Utils
import java.io.File
import java.io.FileWriter
import java.io.PrintWriter
import java.nio.file.Files

object MissFinder {
  private val prefix = "resources/pearls/essay2/"
  val lowerFileName = prefix + "lower.txt"
  val higherFileName = prefix + "higher.txt"
  private val tempFileName = prefix + "tempdata.txt"
    
  def getCenter(start: Int, end: Int) = {
    (start + end)/2
  }
  
  def getValue(fileName: String): Int = {
  	val lines = Source.fromFile(fileName).getLines()
  	if (lines.isEmpty) 
  		0
  	else 
  		lines.next().toInt
  }
  
  def getMissedValue(): Int = {
  	getMissedValue(getValue(lowerFileName), getValue(higherFileName))
  }
  
  def getMissedValue(lowerValue: Int, higherValue: Int): Int = {
  	if (higherValue > 0) {
  		higherValue - 1
  	} else {
  		lowerValue + 1
  	}
  }
  
  def findNo(fileName: String): Int = {
    def findMissedSector(start: Int, end: Int, missedFileName: String): Int = {
      println("start : " + start + " end : " + end)
      
      if (end - start == 1 || end - start == 0) {
        return getMissedValue()
      }
        
      val center = getCenter(start, end)
      var lower = 0
      var higher = 0
      
      Files.copy(new File(missedFileName).toPath(), new File(tempFileName).toPath())
      
      val lowerPW = new PrintWriter(new File(lowerFileName))
      val higherPW = new PrintWriter(new File(higherFileName))
      val linesSource = Source.fromFile(tempFileName)
      val lines = linesSource.getLines
      lines.foreach(line => {
        if (line.toInt > center) {
          higher += 1
          higherPW.println(line)
        } else {
          lower += 1
          lowerPW.println(line)
        }
      })

      lowerPW.close
      higherPW.close
      linesSource.close
      Files.delete(new File(tempFileName).toPath())

      println("center : " + center + " higher : " + higher + " lower : " + lower)
      if (end - center == higher)
        findMissedSector(start, center, lowerFileName)
      else
        findMissedSector(center, end, higherFileName)
    }
    
    findMissedSector(0, 1000, prefix + fileName)
  }

  def main(args: Array[String]): Unit = {
    println(findNo("missfiledata.txt"))
  }
}

위와 같이 구현한 후 지금까지 내가 알고 있는 scala 지식 범위 내에서 리팩토링 시작해 다음 단계까지 진행했다.

object Utils {
  def withPrintWriter(fileName: String)(op: PrintWriter => Unit) {
    withPrintWriter(new File(fileName))(op)
  }
  
  def withPrintWriter(f: File)(op: PrintWriter => Unit) {
    val p = new PrintWriter(f)
    try { op(p) } finally { p.close() }
  }

  def withPrintWriter2(fileName1: String, fileName2: String)(op: (PrintWriter, PrintWriter) => Unit) {
    withPrintWriter2(new File(fileName1), new File(fileName2))(op)
  }  
  
  def withPrintWriter2(f1: File, f2: File)(op: (PrintWriter, PrintWriter) => Unit) {
    val p1 = new PrintWriter(f1)
    val p2 = new PrintWriter(f2)
    try { op(p1, p2) } finally { p1.close; p2.close }
  }
  
  def withFileWriter(f: File)(op: FileWriter => Unit) {
    val fw = new FileWriter(f, true)
    try { op(fw) } finally { fw.close() }
  }
  
  def withFileLines(fileName: String)(op: Iterator[String] => Unit) {
    val source = Source.fromFile(fileName)
    try { op(source.getLines()) } finally { source.close }
  }
}
package pearls.essay2

import scala.io.Source
import pearls.support.Utils._

import java.io.File
import java.io.FileWriter
import java.io.PrintWriter
import java.nio.file.Files

object MissFinder {
  private val prefix = "resources/pearls/essay2/"
  private val tempFileName = prefix + "tempdata.txt"
  val lowerFileName = prefix + "lower.txt"
  val higherFileName = prefix + "higher.txt"

  def getCenter(start: Int, end: Int) = {
    (start + end) / 2
  }

  def getValue(fileName: String): Int = {
    val lines = Source.fromFile(fileName).getLines()
    if (lines.isEmpty)
      return 0
    lines.next().toInt
  }

  def getMissedValue(): Int = {
    getMissedValue(getValue(lowerFileName), getValue(higherFileName))
  }

  def getMissedValue(lowerValue: Int, higherValue: Int): Int = {
    if (higherValue > 0)
      return higherValue - 1
    lowerValue + 1
  }
  
  def findNo(fileName: String): Int = {
    def copyToWorkingFromOriginal(originalFileName: String) {
      if (new File(tempFileName).exists)
        Files.delete(new File(tempFileName).toPath())
      Files.copy(new File(originalFileName).toPath(), new File(tempFileName).toPath())
    }
    
    def divideHighAndLow(center: Int): Tuple2[Int, Int] = {
      var lower, higher = 0
      
      def plusOneAndPrint(lowerFile: PrintWriter, higherFile: PrintWriter, line: String) {
        if (line.toInt > center) { higher += 1; higherFile.println(line) } 
        else {lower += 1; lowerFile.println(line)} 
      }
      
      withFileLines(tempFileName) {
        lines => {
          withPrintWriter2(lowerFileName, higherFileName) {
            (lowerFile, higherFile) => { lines.foreach { line => plusOneAndPrint(lowerFile, higherFile, line) }}
          }          
        }
      }

      println(" lower : " + lower + ", center : " + center + ", higher : " + higher)
      (lower, higher)
    }
    
    def findMissedSector(start: Int, end: Int, originalFileName: String): Int = {
      if (end - start == 1 || end - start == 0) {
        return getMissedValue()
      }

      copyToWorkingFromOriginal(originalFileName)
      
      val center = getCenter(start, end)
      val lowAndHigh = divideHighAndLow(center)
      
      if (end - center == lowAndHigh._2)
        findMissedSector(start, center, lowerFileName)
      else
        findMissedSector(center, end, higherFileName)
    }

    findMissedSector(0, 1000, prefix + fileName)
  }

  def main(args: Array[String]): Unit = {
    println(findNo("missfiledata.txt"))
  }
}

물론 아직도 부족한 부분이 많아 보이고 복잡하다. 내가 집중한 부분은 파일에서 데이터를 읽어 두 개의 파일로 나누는 부분에 대한 리팩토링 작업을 연습해 봤다. 그 코드만 살펴보면 다음과 같다.

    def divideHighAndLow(center: Int): Tuple2[Int, Int] = {
      var lower, higher = 0
      
      def plusOneAndPrint(lowerFile: PrintWriter, higherFile: PrintWriter, line: String) {
        if (line.toInt > center) { higher += 1; higherFile.println(line) } 
        else {lower += 1; lowerFile.println(line)} 
      }
      
      withFileLines(tempFileName) {
        lines => {
          withPrintWriter2(lowerFileName, higherFileName) {
            (lowerFile, higherFile) => { lines.foreach { line => plusOneAndPrint(lowerFile, higherFile, line) }}
          }          
        }
      }

      println(" lower : " + lower + ", center : " + center + ", higher : " + higher)
      (lower, higher)
    }

위 소스 코드에서 lower와 higher에 +1하는 방식으로 구현하고 있는데 val을 쓰면서 처리할 수 있는 방법은 없을까? lower와 higher 값을 가지는 객체를 만들까도 생각해 봤다.

scala 나름 연습해 볼 수록 재미있는 언어라는 생각이 든다. 근데 너무 자유도가 높아서 슬슬 짜증이 나려고도 한다.

6개의 의견 from SLiPP

2015-05-14 13:30

코딩 스타일에 대해서 몇 가지 말씀을 드리자면..

1) 상수 사용:

private val prefix = "resources/pearls/essay2/"

여기서 prefix는 상수라는걸 (물론 val이면 상수지 뭐.. 하실 수도 있지만, 계산 도중에 사용하는 불변값 val과 프로그램에서 상수로 명확히 사용하는 상수는 조금 다르다고 봅니다) 밝히기 위해 대문자로 쓰시는게 좋을 것 같습니다.

(예)

private val Prefix = "resources/pearls/essay2/"

2) 짧은 메소드/함수에서 중괄호 사용:

def getCenter(start: Int, end: Int) = {
    (start + end)/2
}

한줄짜리 간단한 코드는 {}를 없애고 한줄로 쓰시는게 간결하죠.

(예)

def getCenter(start: Int, end: Int) = (start + end)/2
def getMissedValue() = getMissedValue(getValue(lowerFileName), getValue(higherFileName))

3) return 사용은 안 하는게 좋습니다. 뭐.. 억지로 부자연스럽게 바꿀 필요는 없겠지요.

4) getMissedValue()가 하는 일은 결국 getMissedValue(getValue(lowerFileName), getValue(higherFileName))를 호출한는건데, 그냥 중간단계를 생략하고 바로 호출하시는게 좋을듯.

5) findNo()안에 findMissedSector()를 넣어둘 이유가 없습니다. 그냥 findNo()를 없애고 메인에서 바로 findMissedSector(0, 1000, prefix + fileName)를 호출하시는게 좋을듯...

6) findMissedSector가 꼬리재귀이기 때문에 tailrec을 달아서 루프로 최적화하도록 힌트를 주시면 좋을듯.

@tailrec def findMissedSector(start: Int, end: Int, missedFileName: String):Int = ...

마지막으로.. "위 소스 코드에서 lower와 higher에 +1하는 방식으로 구현하고 있는데 val을 쓰면서 처리할 수 있는 방법은 없을까?" 에 대한 답은 fold를 사용...(돌려보진 않았습니다.. 문법 오류가 있을수도..)

def divideHighAndLow(center: Int): Tuple2[Int, Int] = {
  withFileLines(tempFileName) {
    lines => {
      withPrintWriter2(lowerFileName, higherFileName) {
        (lowerFile, higherFile) => {
          def plusOneAndPrint(hl:(Int, Int), line: String) =  // lexical scoping 사용해서 higherFile/lowerFile을 인자로 넘기지 않아도 됨.
            if (line.toInt > center) { 
              higherFile.println(line); (hl._1,hl._2+1) 
            }  else { 
              lowerFile.println(line); (hl._1+1,hl._2) 
            } 
	  lines.foldLeft(lower, higher) plusOneAndPrint // foldLeft와 ( 사이에 공백 있으면 안됨.
        }
      }
    }          
  }
  
  println(" lower : " + lower + ", center : " + center + ", higher : " + higher)
  (lower, higher)
}
2015-05-14 14:43

@enshahar 구체적인 코드 리뷰 감사합니다. 새로운 언어를 배우거나 지식을 학습할 때 피드백을 받으면 느끼는 바가 많더라고요. 부족한 코드에 대해 좋은 피드백 남겨 주셔서 학습하는데 많은 도움이 되겠네요. 주신 의견 반영하도록 연습할께요.

foldLeft는 제가 직접 해볼께요. foldLeft, foldRight function은 알고 있는데 아직 어떤 곳에 활용할 수 있을지 바로 바로 적용이 쉽지 않네요.

2015-05-14 15:23

아.. 깜빡하고 타입을 안 맞췄네요. fold를 사용하고 부수효과를 사용하지 않으려면 loan pattern에서 반환값을 내놓게 타입파라미터화 시키고...

def withPrintWriter[T](f: File)(op: PrintWriter => T) {
    val p = new PrintWriter(f)
    try { op(p) } finally { p.close() }
  }

divideHighAndLow 안에서는 이 반환값을 사용하게 조금 변경하면 될것 같습니다.

2015-05-14 16:02

@자바지기 Scala에서 foldLeft, foldRight, reduceLeft, reduceRight 등의 함수들은 이쪽 계열 함수들 중 끝판왕입니다. 안되는게 없지요. Ruby나 Smalltalk의 inject와 같은 역할입니다.

참고로 Scala에서 reduce와 fold의 차이는 리턴값의 타입 차이입니다. List[Int]를 리듀스해서 리턴값도 Int 타입이면 reduceLeft, List[Int]를 리듀스해서 리턴값이 Int 아닌 다른 타입이면 foldLeft 쓰면 됩니다.

이는 함수 시그너쳐를 보면 알 수 있는데,

def foldLeft [B] (z: B)(f: (B, A) => B): B

def reduceLeft [B >: A] (f: (B, A) => B): B

[B]와 [B >: A] 차이입니다.

2015-05-14 16:35

첨언하자면, reduce와 foldLeft의 차이는 타입 차이뿐이 아닙니다.

reduce는 evaluation order를 보장해 주지 않고, foldLeft는 왼쪽에서 오른쪽으로 evaluation 한다는걸 보장해 줍니다. 그래서 foldLeft에 넘기는 함수엔 결합성에 대한 언급이 없지만, reduce에 넘기는 함수에는 결합법칙이 성립하는 연산이어야 한다는 조건이 붙어있습니다.

물론 결과적으로 구현이 같을 수는 있지만, 구현에 따라서는 reduce는 병렬 계산을 적용해 계산한다던지 하는 장난도 가능하기 때문에 결합법칙에 대해 약간의 주의가 필요합니다.

2015-05-14 23:11

@enshahar 피드백 주신 내용 반영해서 다음과 같이 적용하니 잘 되네요. 물론 적용 과정에서 삽질도 많이 했네요.

    def divideHighAndLow(center: Int): (Int, Int) = {
      withFileLines[(Int, Int)](tempFileName) {
        lines =>
          withPrintWriter2[(Int, Int)](lowerFileName, higherFileName) {
            (lowerFile, higherFile) =>
              def plusOneAndPrint(hl: (Int, Int), line: String) = {
                if (line.toInt > center) { higherFile.println(line); (hl._1, hl._2 + 1) }
                else { lowerFile.println(line); (hl._1 + 1, hl._2) }
              }

              lines.foldLeft(0, 0)(plusOneAndPrint)
          }
      }
    }

피드백 후 리팩토링한 전체 코드 올렸으니 참고하세요.

의견 추가하기